feat(compiler): introduce concrete-protocol

This commit:
 + Adds support for a protocol that enables interop between concrete,
   tfhe-rs, and potentially other contributors to the fhe ecosystem.
 + Gets rid of hand-rolled serialization in the compiler and
   client/server libs.
 + Refactors the client/server libs to allow more pre/post processing of
   circuit inputs/outputs.

The protocol is specified in a capnp file, which defines several types of
objects, among which (a reading sketch follows the list):
 + ProgramInfo objects, which precisely describe a set of fhe circuits
   coming from the same compilation (that is, the function type
   information) and the associated key set.
 + *Key objects, which represent the secret/public keys used to
   encrypt/execute fhe circuits.
 + Value objects, which represent the values transferred between client
   and server to support calls to fhe circuits.
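
As an illustration of how a consumer might read these objects, here is a
minimal sketch using the generated capnp C++ API; the getCircuits/getName
accessors are assumptions about the schema, not its actual field names:

    // Hedged sketch: reads a serialized ProgramInfo message from a file
    // descriptor; the accessors marked below are hypothetical.
    #include <capnp/message.h>
    #include <capnp/serialize.h>
    #include "concrete-protocol.capnp.h"

    void listCircuits(int fd) {
      ::capnp::StreamFdMessageReader reader(fd);
      auto info = reader.getRoot<concreteprotocol::ProgramInfo>();
      for (auto circuit : info.getCircuits()) { // hypothetical accessor
        auto name = circuit.getName();          // hypothetical accessor
        (void)name; // e.g. index the circuit by name before calling it
      }
    }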

The hand-rolled serialization that was previously used is dropped entirely
in favor of capnp across the whole codebase.
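
For instance, where a bespoke binary writer previously existed, a Value can
now be turned into a flat byte buffer with stock capnp helpers. This is only
a sketch of the pattern, assuming the protocol message is indeed named Value;
it is not the exact helper used in the codebase:

    // Hedged sketch of capnp-based serialization replacing the hand-rolled
    // format.
    #include <capnp/message.h>
    #include <capnp/serialize.h>
    #include "concrete-protocol.capnp.h"

    kj::Array<capnp::word> toBytes(concreteprotocol::Value::Reader value) {
      ::capnp::MallocMessageBuilder builder;
      builder.setRoot(value);                      // deep-copy the value
      return ::capnp::messageToFlatArray(builder); // contiguous wire format
    }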

The client/server libs are refactored to introduce a modular design for
pre/post processing. Reading the ProgramInfo file associated with a
compilation, the client and server libs assemble a pipeline of transformers
(functions) for pre- and post-processing of values going into and coming
out of a circuit. This design properly decouples the various aspects of the
processing and allows these capabilities to be safely extended.
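
Conceptually, each input and output position gets its own chain of steps
(encode, encrypt, pack, and so on), assembled once from the ProgramInfo and
reused for every call. A minimal, illustrative composition helper, not the
actual Transformers API, could look like this:

    // Illustrative only: the real transformer types live in
    // concretelang/Common/Transformers.h.
    #include <functional>
    #include <vector>

    template <typename T>
    using Transformer = std::function<T(T)>;

    template <typename T>
    Transformer<T> compose(std::vector<Transformer<T>> steps) {
      return [steps = std::move(steps)](T v) {
        for (const auto &step : steps)
          v = step(v); // apply each pre/post-processing stage in order
        return v;
      };
    }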

In practice this commit includes the following:
 + Defines the specification in a concreteprotocol package.
 + Integrates the compilation of this package as a compiler dependency
   via cmake.
 + Modifies the compiler to use the Encodings objects defined in the
   protocol.
 + Modifies the compiler to emit ProgramInfo files as compilation
   artifacts, and gets rid of the bloated ClientParameters.
 + Introduces a new Common library containing the functionality shared
   between the compiler and the client/server libs.
 + Introduces a functional pre/post processing pipeline in this common
   library.
 + Modifies the client/server libs to support loading ProgramInfo
   objects and calling circuits using Value messages (see the usage
   sketch after this list).
 + Drops support for JIT.
 + Drops support for the C API.
 + Drops support for the Rust bindings.
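
As referenced above, here is a rough usage sketch of the refactored client
lib, based on the ClientProgram/ClientCircuit API added in this commit; the
header path, the "main" circuit name, and the Result/OUTCOME_TRY plumbing
are assumptions:

    // Hedged sketch: encrypt one argument of a circuit named "main".
    #include <memory>
    #include "concretelang/ClientLib/ClientLib.h" // path is an assumption

    using namespace concretelang::clientlib;

    Result<TransportValue> encryptFirstArg(
        const Message<concreteprotocol::ProgramInfo> &programInfo,
        const ClientKeyset &keyset, std::shared_ptr<CSPRNG> csprng,
        Value arg) {
      // Assemble the per-circuit transformer pipelines from the ProgramInfo.
      OUTCOME_TRY(auto program,
                  ClientProgram::create(programInfo, keyset, csprng));
      OUTCOME_TRY(auto circuit, program.getClientCircuit("main"));
      // Pre-process (encode + encrypt) the value for input position 0.
      return circuit.prepareInput(arg, /*pos=*/0);
    }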

Co-authored-by: Nikita Frolov <nf@mkmks.org>
Author: Alexandre Péré, 2023-10-27 16:32:40 +02:00
Committed by: Alexandre Péré
Commit: e8ef48ffd8 (parent: 9139101cc3)
207 changed files with 8601 additions and 16816 deletions


@@ -48,5 +48,3 @@ _build/
# macOS
.DS_Store
compiler/lib/Bindings/Rust/target/


@@ -1,5 +1,6 @@
# Build dirs
build*/
.cache/
*.mlir.script
*.lit_test_times.txt


@@ -3,13 +3,14 @@ cmake_minimum_required(VERSION 3.17)
project(concretecompiler LANGUAGES C CXX)
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set_property(GLOBAL PROPERTY GLOBAL_DEPENDS_DEBUG_MODE 0)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
# Needed on linux with clang 15 and on MacOS because cxx emits dollars in the optimizer C++ API
add_definitions("-Wno-dollar-in-identifier-extension")
add_definitions("-Wno-c++98-compat-extra-semi")
add_definitions("-Wall ")
add_definitions("-Werror ")
add_definitions("-Wfatal-errors")
@@ -66,6 +67,19 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
# -------------------------------------------------------------------------------
include_directories(${PROJECT_SOURCE_DIR}/../../../tools/parameter-curves/concrete-security-curves-cpp/include)
# -------------------------------------------------------------------------------
# Concrete Protocol
# -------------------------------------------------------------------------------
set(CONCRETE_PROTOCOL_DIR "${PROJECT_SOURCE_DIR}/../../../tools/concrete-protocol")
add_subdirectory(${CONCRETE_PROTOCOL_DIR} concrete-protocol)
get_target_property(CONCRETE_PROTOCOL_GEN_DIR concrete-protocol BINARY_DIR)
set(CONCRETE_PROTOCOL_CAPNP_SRC "${CONCRETE_PROTOCOL_GEN_DIR}/capnp_src_dir/c++/src")
include_directories(${CONCRETE_PROTOCOL_GEN_DIR})
include_directories(${CONCRETE_PROTOCOL_CAPNP_SRC})
add_dependencies(mlir-headers concrete-protocol)
install(TARGETS concrete-protocol EXPORT concrete-protocol)
install(EXPORT concrete-protocol DESTINATION "./")
# -------------------------------------------------------------------------------
# Concrete Optimizer
# -------------------------------------------------------------------------------


@@ -110,7 +110,7 @@ else
PYTHON_TESTS_MARKER="not parallel"
endif
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc
# HPX #####################################################
@@ -174,14 +174,6 @@ python-bindings: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
rust-bindings: install
cd lib/Bindings/Rust && \
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
cargo build --release
CAPI:
cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG CONCRETELANGCAPISupport
clientlib: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
@@ -249,14 +241,6 @@ run-python-tests: python-bindings concretecompiler
test-compiler-file-output: concretecompiler
pytest -vs tests/test_compiler_file_output
## rust-tests
run-rust-tests: rust-bindings
cd lib/Bindings/Rust && \
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
LD_LIBRARY_PATH=$(INSTALL_PATH)/lib \
cargo test --release
## end-to-end-tests
build-end-to-end-jit-chunked-int: build-initialized
@@ -307,7 +291,7 @@ run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-t
$(foreach optimizer_strategy,$(OPTIMIZATION_STRATEGY_TO_TEST), $(foreach security,$(SECURITY_TO_TEST), \
$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) \
--optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml || exit $$?;))
--optimizer-strategy=$(optimizer_strategy) $(FIXTURE_CPU_DIR)/*.yaml || exit $$?;))
### end-to-end-tests GPU
@@ -327,7 +311,7 @@ generate-gpu-tests: $(FIXTURE_GPU_DIR) $(FIXTURE_GPU_DIR)/end_to_end_apply_looku
run-end-to-end-tests-gpu: build-end-to-end-test generate-gpu-tests
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
--backend=gpu --library /tmp/concrete_compiler/gpu_tests/ \
--backend=gpu \
$(FIXTURE_GPU_DIR)/*.yaml
## end-to-end-dataflow-tests
@@ -477,12 +461,6 @@ python-format:
python-lint:
pylint --rcfile=../pylintrc lib/Bindings/Python/concrete/compiler
check-rust-format:
cd lib/Bindings/Rust && cargo fmt --check
rust-format:
cd lib/Bindings/Rust && cargo fmt
# libraries we want to have in the installation that aren't already a deps of other targets
install-deps:
cmake --build $(BUILD_DIR) --target MLIRCAPIRegisterEverything
@@ -529,7 +507,7 @@ darwin-python-package:
python-package: python-bindings $(OS)-python-package
@echo The python package is: $(BUILD_DIR)/wheels/*.whl
install: concretecompiler CAPI install-deps
install: concretecompiler install-deps
$(info Install prefix set to $(INSTALL_PREFIX))
$(info Installing under $(INSTALL_PATH))
mkdir -p $(INSTALL_PATH)/include


@@ -1,48 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_C_DIALECT_FHE_H
#define CONCRETELANG_C_DIALECT_FHE_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
/// \brief structure to return an MlirType or report that there was an error
/// during type creation.
typedef struct {
MlirType type;
bool isError;
} MlirTypeOrError;
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
/// Creates an encrypted integer type of `width` bits
MLIR_CAPI_EXPORTED MlirTypeOrError
fheEncryptedIntegerTypeGetChecked(MlirContext context, unsigned width);
/// If the type is an EncryptedInteger
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType);
/// Creates an encrypted signed integer type of `width` bits
MLIR_CAPI_EXPORTED MlirTypeOrError
fheEncryptedSignedIntegerTypeGetChecked(MlirContext context, unsigned width);
/// If the type is an EncryptedSignedInteger
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedSignedIntegerType(MlirType);
/// \brief Get bitwidth of the encrypted integer type.
///
/// \return bitwidth of the encrypted integer or 0 if it's not an encrypted
/// integer
MLIR_CAPI_EXPORTED unsigned fheTypeIntegerWidthGet(MlirType);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_DIALECT_FHE_H


@@ -1,21 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_C_DIALECT_FHELINALG_H
#define CONCRETELANG_C_DIALECT_FHELINALG_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_DIALECT_FHELINALG_H


@@ -1,21 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_C_DIALECT_TRACING_H
#define CONCRETELANG_C_DIALECT_TRACING_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TRACING, tracing);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_DIALECT_TRACING_H


@@ -1,406 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
#define CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
/// The CAPI should be really careful about memory allocation. Every pointer
/// returned should point to a new buffer allocated for the purpose of the
/// CAPI, and should have a respective destructor function.
/// Opaque type declarations. Inspired from
/// llvm-project/mlir/include/mlir-c/IR.h
///
/// Adds an error pointer to an allocated buffer holding the error message if
/// any.
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
const char *error; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(CompilerEngine, void);
DEFINE_C_API_STRUCT(CompilationContext, void);
DEFINE_C_API_STRUCT(CompilationResult, void);
DEFINE_C_API_STRUCT(Library, void);
DEFINE_C_API_STRUCT(LibraryCompilationResult, void);
DEFINE_C_API_STRUCT(LibrarySupport, void);
DEFINE_C_API_STRUCT(CompilationOptions, void);
DEFINE_C_API_STRUCT(OptimizerConfig, void);
DEFINE_C_API_STRUCT(ServerLambda, void);
DEFINE_C_API_STRUCT(Encoding, void);
DEFINE_C_API_STRUCT(EncryptionGate, void);
DEFINE_C_API_STRUCT(CircuitGate, void);
DEFINE_C_API_STRUCT(ClientParameters, void);
DEFINE_C_API_STRUCT(KeySet, void);
DEFINE_C_API_STRUCT(KeySetCache, void);
DEFINE_C_API_STRUCT(EvaluationKeys, void);
DEFINE_C_API_STRUCT(LambdaArgument, void);
DEFINE_C_API_STRUCT(PublicArguments, void);
DEFINE_C_API_STRUCT(PublicResult, void);
DEFINE_C_API_STRUCT(CompilationFeedback, void);
#undef DEFINE_C_API_STRUCT
/// NULL Pointer checkers. Generate functions to check if the struct contains a
/// null pointer.
#define DEFINE_NULL_PTR_CHECKER(funcname, storage) \
bool funcname(storage s) { return s.ptr == NULL; }
DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine)
DEFINE_NULL_PTR_CHECKER(compilationContextIsNull, CompilationContext)
DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult)
DEFINE_NULL_PTR_CHECKER(libraryIsNull, Library)
DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull,
LibraryCompilationResult)
DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport)
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions)
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig)
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda)
DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate)
DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding)
DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate)
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters)
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet)
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache)
DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys)
DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument)
DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments)
DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult)
DEFINE_NULL_PTR_CHECKER(compilationFeedbackIsNull, CompilationFeedback)
#undef DEFINE_NULL_PTR_CHECKER
/// Each struct has a creator function that allocates memory for the underlying
/// Cpp object referenced, and a destroy function that frees this allocated
/// memory.
/// ********** Utilities *******************************************************
/// Destroy string references created by the compiler.
///
/// This is not supposed to destroy any string ref, but only the ones we have
/// allocated memory for and know how to free.
MLIR_CAPI_EXPORTED void mlirStringRefDestroy(MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirStringRefIsNull(MlirStringRef str) {
return str.data == NULL;
}
/// ********** BufferRef CAPI **************************************************
/// A struct for binary buffers.
///
/// Contrary to MlirStringRef, it doesn't assume the pointer points to a
/// null-terminated string; the data should be considered as-is, in binary form.
/// Useful for serialized objects.
typedef struct BufferRef {
const char *data;
size_t length;
const char *error;
} BufferRef;
MLIR_CAPI_EXPORTED void bufferRefDestroy(BufferRef buffer);
MLIR_CAPI_EXPORTED bool bufferRefIsNull(BufferRef buffer) {
return buffer.data == NULL;
}
MLIR_CAPI_EXPORTED BufferRef bufferRefCreate(const char *buffer, size_t length);
/// ********** CompilationTarget CAPI ******************************************
enum CompilationTarget {
ROUND_TRIP,
FHE,
TFHE,
PARAMETRIZED_TFHE,
NORMALIZED_TFHE,
BATCHED_TFHE,
CONCRETE,
STD,
LLVM,
LLVM_IR,
OPTIMIZED_LLVM_IR,
LIBRARY
};
typedef enum CompilationTarget CompilationTarget;
/// ********** CompilationOptions CAPI *****************************************
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(
MlirStringRef funcName, bool autoParallelize, bool batchTFHEOps,
bool dataflowParallelize, bool emitGPUOps, bool loopParallelize,
bool optimizeTFHE, OptimizerConfig optimizerConfig, bool verifyDiagnostics);
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault();
MLIR_CAPI_EXPORTED void compilationOptionsDestroy(CompilationOptions options);
/// ********** OptimizerConfig CAPI ********************************************
MLIR_CAPI_EXPORTED OptimizerConfig
optimizerConfigCreate(bool display, double fallback_log_norm_woppbs,
double global_p_error, double p_error, uint64_t security,
bool strategy_v0, bool use_gpu_constraints,
uint32_t ciphertext_modulus_log, uint32_t fft_precision);
MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault();
MLIR_CAPI_EXPORTED void optimizerConfigDestroy(OptimizerConfig config);
/// ********** CompilerEngine CAPI *********************************************
MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate();
MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine);
MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile(
CompilerEngine engine, MlirStringRef module, CompilationTarget target);
MLIR_CAPI_EXPORTED void
compilerEngineCompileSetOptions(CompilerEngine engine,
CompilationOptions options);
/// ********** CompilationResult CAPI ******************************************
/// Get a string reference holding the textual representation of the compiled
/// module. The returned `MlirStringRef` should be destroyed using
/// `mlirStringRefDestroy` to free memory.
MLIR_CAPI_EXPORTED MlirStringRef
compilationResultGetModuleString(CompilationResult result);
MLIR_CAPI_EXPORTED void compilationResultDestroy(CompilationResult result);
/// ********** Library CAPI ****************************************************
MLIR_CAPI_EXPORTED Library libraryCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath,
bool cleanUp);
MLIR_CAPI_EXPORTED void libraryDestroy(Library lib);
/// ********** LibraryCompilationResult CAPI ***********************************
MLIR_CAPI_EXPORTED void
libraryCompilationResultDestroy(LibraryCompilationResult result);
/// ********** LibrarySupport CAPI *********************************************
MLIR_CAPI_EXPORTED LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader);
MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault(
MlirStringRef outputDirPath, MlirStringRef runtimeLibraryPath) {
return librarySupportCreate(outputDirPath, runtimeLibraryPath, true, true,
true, true, true);
}
MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
LibrarySupport support, MlirStringRef module, CompilationOptions options);
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
LibrarySupport support, LibraryCompilationResult result);
MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters(
LibrarySupport support, LibraryCompilationResult result);
MLIR_CAPI_EXPORTED LibraryCompilationResult
librarySupportLoadCompilationResult(LibrarySupport support);
MLIR_CAPI_EXPORTED CompilationFeedback librarySupportLoadCompilationFeedback(
LibrarySupport support, LibraryCompilationResult result);
MLIR_CAPI_EXPORTED PublicResult
librarySupportServerCall(LibrarySupport support, ServerLambda server,
PublicArguments args, EvaluationKeys evalKeys);
MLIR_CAPI_EXPORTED MlirStringRef
librarySupportGetSharedLibPath(LibrarySupport support);
MLIR_CAPI_EXPORTED MlirStringRef
librarySupportGetClientParametersPath(LibrarySupport support);
MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support);
/// ********** ServerLambda CAPI ***********************************************
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
/// ********** ClientParameters CAPI *******************************************
MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params);
MLIR_CAPI_EXPORTED ClientParameters
clientParametersUnserialize(BufferRef buffer);
MLIR_CAPI_EXPORTED ClientParameters
clientParametersCopy(ClientParameters params);
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
/// Returns the number of output circuit gates
MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params);
/// Returns the number of input circuit gates
MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params);
/// Returns the output circuit gate corresponding to the index
///
/// - `index` must be valid.
MLIR_CAPI_EXPORTED CircuitGate
clientParametersOutputCircuitGate(ClientParameters params, size_t index);
/// Returns the input circuit gate corresponding to the index
///
/// - `index` must be valid.
MLIR_CAPI_EXPORTED CircuitGate
clientParametersInputCircuitGate(ClientParameters params, size_t index);
/// Returns the EncryptionGate of the circuit gate.
///
/// - The returned gate will be null if the gate does not represent encrypted
/// data
MLIR_CAPI_EXPORTED EncryptionGate
circuitGateEncryptionGate(CircuitGate circuit_gate);
/// Returns the variance of the encryption gate
MLIR_CAPI_EXPORTED double
encryptionGateVariance(EncryptionGate encryption_gate);
/// Returns the Encoding of the encryption gate.
MLIR_CAPI_EXPORTED Encoding
encryptionGateEncoding(EncryptionGate encryption_gate);
/// Returns the precision (bit width) of the encoding
MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding);
MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate);
MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate);
MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding);
/// ********** KeySet CAPI *****************************************************
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb);
MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet);
MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet);
/// ********** KeySetCache CAPI ************************************************
MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath);
MLIR_CAPI_EXPORTED KeySet
keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb);
MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
/// ********** EvaluationKeys CAPI *********************************************
MLIR_CAPI_EXPORTED BufferRef evaluationKeysSerialize(EvaluationKeys keys);
MLIR_CAPI_EXPORTED EvaluationKeys evaluationKeysUnserialize(BufferRef buffer);
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
/// ********** LambdaArgument CAPI *********************************************
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(
const uint8_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(
const uint16_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(
const uint32_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(
const uint64_t *data, const int64_t *dims, size_t rank);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
uint64_t *buffer);
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
int64_t *buffer);
MLIR_CAPI_EXPORTED PublicArguments
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
ClientParameters params, KeySet keySet);
MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
/// ********** PublicArguments CAPI ********************************************
MLIR_CAPI_EXPORTED BufferRef publicArgumentsSerialize(PublicArguments args);
MLIR_CAPI_EXPORTED PublicArguments
publicArgumentsUnserialize(BufferRef buffer, ClientParameters params);
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
/// ********** PublicResult CAPI ***********************************************
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
KeySet keySet);
MLIR_CAPI_EXPORTED BufferRef publicResultSerialize(PublicResult result);
MLIR_CAPI_EXPORTED PublicResult
publicResultUnserialize(BufferRef buffer, ClientParameters params);
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
/// ********** CompilationFeedback CAPI ****************************************
MLIR_CAPI_EXPORTED double
compilationFeedbackGetComplexity(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED double
compilationFeedbackGetPError(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED double
compilationFeedbackGetGlobalPError(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED uint64_t
compilationFeedbackGetTotalSecretKeysSize(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED uint64_t
compilationFeedbackGetTotalBootstrapKeysSize(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED uint64_t
compilationFeedbackGetTotalKeyswitchKeysSize(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED uint64_t
compilationFeedbackGetTotalInputsSize(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED uint64_t
compilationFeedbackGetTotalOutputsSize(CompilationFeedback feedback);
MLIR_CAPI_EXPORTED void
compilationFeedbackDestroy(CompilationFeedback feedback);
#ifdef __cplusplus
}
#endif
#endif // CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H


@@ -6,10 +6,8 @@
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
#include "concretelang/Common/Compat.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/LibrarySupport.h"
#include "mlir-c/IR.h"
/// MLIR_CAPI_EXPORTED is used here throughout the API, because of the way the
@@ -29,42 +27,6 @@ struct executionArguments {
};
typedef struct executionArguments executionArguments;
// JIT Support bindings ///////////////////////////////////////////////////////
struct JITSupport_Py {
mlir::concretelang::JITSupport support;
};
typedef struct JITSupport_Py JITSupport_Py;
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITSupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile_module(
JITSupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
jit_load_compilation_feedback(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys);
// Library Support bindings ///////////////////////////////////////////////////
struct LibrarySupport_Py {
@@ -78,17 +40,17 @@ library_support(const char *outputPath, const char *runtimeLibraryPath,
bool generateClientParameters, bool generateCompilationFeedback,
bool generateCppHeader);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibrarySupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options);
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile_module(
LibrarySupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibrarySupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options);
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
library_load_client_parameters(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &);
@@ -98,7 +60,8 @@ library_load_compilation_feedback(
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &);
mlir::concretelang::LibraryCompilationResult &,
bool useSimulation);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
library_server_call(LibrarySupport_Py support,
@@ -115,7 +78,7 @@ MLIR_CAPI_EXPORTED std::string
library_get_shared_lib_path(LibrarySupport_Py support);
MLIR_CAPI_EXPORTED std::string
library_get_client_parameters_path(LibrarySupport_Py support);
library_get_program_info_path(LibrarySupport_Py support);
// Client Support bindings ///////////////////////////////////////////////////
@@ -130,28 +93,30 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
MLIR_CAPI_EXPORTED lambdaArgument
decrypt_result(concretelang::clientlib::KeySet &keySet,
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult);
// Serialization ////////////////////////////////////////////////////////////
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
clientParametersUnserialize(const std::string &json);
MLIR_CAPI_EXPORTED std::string
clientParametersSerialize(mlir::concretelang::ClientParameters &params);
clientParametersSerialize(concretelang::clientlib::ClientParameters &params);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
publicArgumentsUnserialize(
mlir::concretelang::ClientParameters &clientParameters,
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &buffer);
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
concretelang::clientlib::PublicArguments &publicArguments);
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
const std::string &buffer);
publicResultUnserialize(
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &buffer);
MLIR_CAPI_EXPORTED std::string
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
@@ -174,6 +139,22 @@ valueUnserialize(const std::string &buffer);
MLIR_CAPI_EXPORTED std::string
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value);
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueExporter createValueExporter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters);
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueExporter
createSimulatedValueExporter(
concretelang::clientlib::ClientParameters &clientParameters);
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueDecrypter createValueDecrypter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters);
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueDecrypter
createSimulatedValueDecrypter(
concretelang::clientlib::ClientParameters &clientParameters);
/// Parse then print a textual representation of an MLIR module
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);


@@ -1,137 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
#define CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
#include <cassert>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
using scalar_in = uint8_t;
using scalar_out = uint64_t;
using tensor1_in = std::vector<scalar_in>;
using tensor2_in = std::vector<std::vector<scalar_in>>;
using tensor3_in = std::vector<std::vector<std::vector<scalar_in>>>;
using tensor1_out = std::vector<scalar_out>;
using tensor2_out = std::vector<std::vector<scalar_out>>;
using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
/// Low-level class to create the client-side view of an FHE function.
class ClientLambda {
public:
virtual ~ClientLambda() = default;
/// Construct a ClientLambda from a ClientParameters file.
static outcome::checked<ClientLambda, StringError> load(std::string funcName,
std::string jsonPath);
/// Generate or get from cache a KeySet suitable for this ClientLambda
outcome::checked<std::unique_ptr<KeySet>, StringError>
keySet(std::shared_ptr<KeySetCache> optionalCache, uint64_t seed_msb,
uint64_t seed_lsb);
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
decryptReturnedValues(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_scalar_t, StringError>
decryptReturnedScalar(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_1_t, StringError>
decryptReturnedTensor1(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_2_t, StringError>
decryptReturnedTensor2(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_3_t, StringError>
decryptReturnedTensor3(KeySet &keySet, PublicResult &result);
public:
ClientParameters clientParameters;
};
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
PublicResult &result);
template <typename Result, typename... Args>
class TypedClientLambda : public ClientLambda {
public:
static outcome::checked<TypedClientLambda<Result, Args...>, StringError>
load(std::string funcName, std::string jsonPath) {
OUTCOME_TRY(auto lambda, ClientLambda::load(funcName, jsonPath));
return TypedClientLambda(lambda);
}
/// Emit a call on this lambda to a binary ostream.
/// The ostream is responsible for transporting the call to a
/// ServerLambda::real_call_write function. ostream must be in binary mode
/// std::ios_base::openmode::binary
outcome::checked<void, StringError>
serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) {
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
return publicArguments->serialize(ostream);
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
publicArguments(Args... args, KeySet &keySet) {
OUTCOME_TRY(
auto clientArguments,
EncryptedArguments::create(/*simulation*/ false, keySet, args...));
return clientArguments->exportPublicArguments(clientParameters);
}
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
PublicResult &result) {
return topLevelDecryptResult<Result>((*this), keySet, result);
}
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
// TODO: check parameter types
// TODO: add static check on types vs lambda inputs/outputs
}
protected:
// Workaround, gcc 6 does not support partial template specialisation in class
template <typename Result_>
friend outcome::checked<Result_, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
PublicResult &result);
};
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
PublicResult &result);
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
PublicResult &result);
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
PublicResult &result);
} // namespace clientlib
} // namespace concretelang
#endif


@@ -0,0 +1,89 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_REFACTORED_H
#define CONCRETELANG_CLIENTLIB_REFACTORED_H
#include <cassert>
#include <cstdint>
#include <cstring>
#include <optional>
#include <string>
#include <variant>
#include "boost/outcome.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Transformers.h"
#include "concretelang/Common/Values.h"
using concretelang::error::Result;
using concretelang::keysets::ClientKeyset;
using concretelang::transformers::InputTransformer;
using concretelang::transformers::OutputTransformer;
using concretelang::transformers::TransformerFactory;
using concretelang::values::TransportValue;
using concretelang::values::Value;
namespace concretelang {
namespace clientlib {
class ClientCircuit {
public:
static Result<ClientCircuit>
create(const Message<concreteprotocol::CircuitInfo> &info,
const ClientKeyset &keyset, std::shared_ptr<CSPRNG> csprng,
bool useSimulation = false);
Result<TransportValue> prepareInput(Value arg, size_t pos);
Result<Value> processOutput(TransportValue result, size_t pos);
std::string getName();
const Message<concreteprotocol::CircuitInfo> &getCircuitInfo();
private:
ClientCircuit() = delete;
ClientCircuit(const Message<concreteprotocol::CircuitInfo> &circuitInfo,
std::vector<InputTransformer> inputTransformers,
std::vector<OutputTransformer> outputTransformers)
: circuitInfo(circuitInfo), inputTransformers(inputTransformers),
outputTransformers(outputTransformers){};
private:
Message<concreteprotocol::CircuitInfo> circuitInfo;
std::vector<InputTransformer> inputTransformers;
std::vector<OutputTransformer> outputTransformers;
};
/// Contains all the context to generate inputs for a server call by the
/// server lib.
class ClientProgram {
public:
/// Generates a fresh client program with fresh keyset on the first use.
static Result<ClientProgram>
create(const Message<concreteprotocol::ProgramInfo> &info,
const ClientKeyset &keyset, std::shared_ptr<CSPRNG> csprng,
bool useSimulation = false);
/// Returns a reference to the named client circuit if it exists.
Result<ClientCircuit> getClientCircuit(std::string circuitName);
private:
ClientProgram() = default;
private:
std::vector<ClientCircuit> circuits;
};
} // namespace clientlib
} // namespace concretelang
#endif


@@ -1,350 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
#include <map>
#include <optional>
#include <string>
#include <vector>
#include "boost/outcome.h"
#include "concretelang/Common/Error.h"
#include <llvm/Support/JSON.h>
namespace concretelang {
inline size_t bitWidthAsWord(size_t exactBitWidth) {
if (exactBitWidth <= 8)
return 8;
if (exactBitWidth <= 16)
return 16;
if (exactBitWidth <= 32)
return 32;
if (exactBitWidth <= 64)
return 64;
assert(false && "Bit width > 64 not supported");
}
namespace clientlib {
using concretelang::error::StringError;
const uint64_t SMALL_KEY = 1;
const uint64_t BIG_KEY = 0;
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
typedef uint64_t DecompositionLevelCount;
typedef uint64_t DecompositionBaseLog;
typedef uint64_t PolynomialSize;
typedef uint64_t Precision;
typedef double Variance;
typedef std::vector<int64_t> CRTDecomposition;
typedef uint64_t LweDimension;
typedef uint64_t GlweDimension;
typedef uint64_t LweSecretKeyID;
struct LweSecretKeyParam {
LweDimension dimension;
void hash(size_t &seed);
inline uint64_t lweDimension() { return dimension; }
inline uint64_t lweSize() { return dimension + 1; }
inline uint64_t byteSize() { return lweSize() * 8; }
};
static bool operator==(const LweSecretKeyParam &lhs,
const LweSecretKeyParam &rhs) {
return lhs.dimension == rhs.dimension;
}
typedef uint64_t BootstrapKeyID;
struct BootstrapKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
Variance variance;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
void hash(size_t &seed);
uint64_t byteSize(uint64_t inputLweSize, uint64_t outputLweSize) {
return inputLweSize * level * (glweDimension + 1) * (glweDimension + 1) *
outputLweSize * 8;
}
};
static inline bool operator==(const BootstrapKeyParam &lhs,
const BootstrapKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
}
typedef uint64_t KeyswitchKeyID;
struct KeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
Variance variance;
void hash(size_t &seed);
size_t byteSize(size_t inputLweSize, size_t outputLweSize) {
return level * inputLweSize * outputLweSize * 8;
}
};
static inline bool operator==(const KeyswitchKeyParam &lhs,
const KeyswitchKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.variance == rhs.variance;
}
typedef uint64_t PackingKeyswitchKeyID;
struct PackingKeyswitchKeyParam {
LweSecretKeyID inputSecretKeyID;
LweSecretKeyID outputSecretKeyID;
DecompositionLevelCount level;
DecompositionBaseLog baseLog;
GlweDimension glweDimension;
PolynomialSize polynomialSize;
LweDimension inputLweDimension;
Variance variance;
void hash(size_t &seed);
};
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
const PackingKeyswitchKeyParam &rhs) {
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
lhs.glweDimension == rhs.glweDimension &&
lhs.polynomialSize == rhs.polynomialSize &&
lhs.variance == rhs.variance &&
lhs.inputLweDimension == rhs.inputLweDimension;
}
struct Encoding {
Precision precision;
CRTDecomposition crt;
bool isSigned;
};
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
return lhs.precision == rhs.precision && lhs.isSigned == rhs.isSigned;
}
struct EncryptionGate {
LweSecretKeyID secretKeyID;
Variance variance;
Encoding encoding;
};
static inline bool operator==(const EncryptionGate &lhs,
const EncryptionGate &rhs) {
return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance &&
lhs.encoding == rhs.encoding;
}
struct CircuitGateShape {
/// Width of the scalar value
uint64_t width;
/// Dimensions of the tensor, empty if scalar
std::vector<int64_t> dimensions;
/// Size of the buffer containing the tensor
uint64_t size;
// Indicates whether elements are signed
bool sign;
};
static inline bool operator==(const CircuitGateShape &lhs,
const CircuitGateShape &rhs) {
return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions &&
lhs.size == rhs.size;
}
struct ChunkInfo {
/// total number of bits used for the chunk including the carry.
/// size should be at least width + 1
unsigned int size;
/// number of bits used for the chunk excluding the carry
unsigned int width;
};
static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
return lhs.width == rhs.width && lhs.size == rhs.size;
}
struct CircuitGate {
std::optional<EncryptionGate> encryption;
CircuitGateShape shape;
std::optional<ChunkInfo> chunkInfo;
bool isEncrypted() { return encryption.has_value(); }
/// byteSize returns the size in bytes for this gate.
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
auto width = shape.width;
auto numElts = shape.size == 0 ? 1 : shape.size;
if (isEncrypted()) {
assert(encryption->secretKeyID < secretKeys.size());
auto skParam = secretKeys[encryption->secretKeyID];
return 8 * skParam.lweSize() * numElts;
}
width = bitWidthAsWord(width) / 8;
return width * numElts;
}
};
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape &&
lhs.chunkInfo == rhs.chunkInfo;
}
struct ClientParameters {
std::vector<LweSecretKeyParam> secretKeys;
std::vector<BootstrapKeyParam> bootstrapKeys;
std::vector<KeyswitchKeyParam> keyswitchKeys;
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
std::vector<CircuitGate> inputs;
std::vector<CircuitGate> outputs;
std::string functionName;
size_t hash();
static outcome::checked<std::vector<ClientParameters>, StringError>
load(std::string path);
static std::string getClientParametersPath(std::string path);
outcome::checked<CircuitGate, StringError> input(size_t pos) {
if (pos >= inputs.size()) {
return StringError("input gate ") << pos << " didn't exists";
}
return inputs[pos];
}
outcome::checked<CircuitGate, StringError> ouput(size_t pos) {
if (pos >= outputs.size()) {
return StringError("output gate ") << pos << " didn't exists";
}
return outputs[pos];
}
outcome::checked<LweSecretKeyParam, StringError>
lweSecretKeyParam(CircuitGate gate) {
if (!gate.encryption.has_value()) {
return StringError("gate is not encrypted");
}
assert(gate.encryption->secretKeyID < secretKeys.size());
auto secretKey = secretKeys[gate.encryption->secretKeyID];
return secretKey;
}
/// bufferSize returns the size of the whole buffer of a gate.
int64_t bufferSize(CircuitGate gate) {
if (!gate.encryption.has_value()) {
// Value is not encrypted, just return the tensor size
return gate.shape.size;
}
auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size;
// Size of the ciphertext
return shapeSize * lweBufferSize(gate);
}
/// lweBufferSize returns the size of one ciphertext of a gate.
int64_t lweBufferSize(CircuitGate gate) {
assert(gate.encryption.has_value());
auto nbBlocks = gate.encryption->encoding.crt.size();
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
auto param = lweSecretKeyParam(gate);
assert(param.has_value());
return param.value().lweSize() * nbBlocks;
}
/// bufferShape returns the shape of the tensor for the given gate. It returns
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts
/// (if not in simulation).
std::vector<int64_t> bufferShape(CircuitGate gate, bool simulation = false) {
if (!gate.encryption.has_value()) {
// Value is not encrypted, just return the tensor shape
return gate.shape.dimensions;
}
auto lweSecreteKeyParam = lweSecretKeyParam(gate);
assert(lweSecreteKeyParam.has_value());
// Copy the shape
std::vector<int64_t> shape(gate.shape.dimensions);
auto crt = gate.encryption->encoding.crt;
// CRT case: Add one dimension equals to the number of blocks
if (!crt.empty()) {
shape.push_back(crt.size());
}
if (!simulation) {
// Add one dimension for the size of ciphertext(s)
shape.push_back(lweSecreteKeyParam.value().lweSize());
}
return shape;
}
};
static inline bool operator==(const ClientParameters &lhs,
const ClientParameters &rhs) {
return lhs.secretKeys == rhs.secretKeys &&
lhs.bootstrapKeys == rhs.bootstrapKeys &&
lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == rhs.inputs &&
lhs.outputs == rhs.outputs;
}
llvm::json::Value toJSON(const LweSecretKeyParam &);
bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const BootstrapKeyParam &);
bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const KeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
llvm::json::Value toJSON(const EncryptionGate &);
bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitGateShape &);
bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitGate &);
bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
llvm::json::Value toJSON(const ClientParameters &);
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
ClientParameters cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
static inline llvm::raw_ostream &operator<<(llvm::raw_string_ostream &OS,
ClientParameters cp) {
return OS << llvm::formatv("{0:2}", toJSON(cp));
}
} // namespace clientlib
} // namespace concretelang
#endif


@@ -1,189 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
#define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
#include <ostream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/ClientLib/ValueExporter.h"
#include "concretelang/Common/BitsSize.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class PublicArguments;
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Preferably use TypedClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use
/// serializeCall(PublicArguments, KeySet).
class EncryptedArguments {
public:
EncryptedArguments(bool simulation = false) : simulation(simulation) {}
/// Encrypts args with the given KeySet and packs the encrypted arguments
/// into an EncryptedArguments
template <typename... Args>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(bool simulation, KeySet &keySet, Args... args) {
auto encryptedArgs = std::make_unique<EncryptedArguments>(simulation);
OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...));
return std::move(encryptedArgs);
}
template <typename ArgT>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(bool simulation, KeySet &keySet, const llvm::ArrayRef<ArgT> args) {
auto encryptedArgs = EncryptedArguments::empty(simulation);
for (size_t i = 0; i < args.size(); i++) {
OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet));
}
OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet));
return std::move(encryptedArgs);
}
static std::unique_ptr<EncryptedArguments> empty(bool simulation = false) {
return std::make_unique<EncryptedArguments>(simulation);
}
bool isSimulated() { return simulation; }
/// Export encrypted arguments as public arguments, reset the encrypted
/// arguments, i.e. move all buffers to the PublicArguments and reset the
/// positional counter.
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
exportPublicArguments(ClientParameters clientParameters);
/// Check that all arguments have been pushed.
// TODO: Remove public method here
outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
outcome::checked<void, StringError> checkAllArgs(ClientParameters &params);
public:
/// Add a uint64_t scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet) {
auto exporter = getExporter(keySet);
OUTCOME_TRY(auto value, exporter->exportValue(arg, values.size()));
values.push_back(std::move(value));
return outcome::success();
}
/// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
KeySet &keySet) {
return pushArg((uint8_t *)arg.data(),
llvm::ArrayRef<int64_t>{(int64_t)arg.size()}, keySet);
}
/// Add a 1D tensor argument with data and size of the dimension.
template <typename T>
outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
KeySet &keySet) {
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
}
/// Add a 1D tensor argument.
template <size_t size>
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
KeySet &keySet) {
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size},
keySet);
}
/// Add a 2D tensor argument.
template <size_t size0, size_t size1>
outcome::checked<void, StringError>
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size0, size1},
keySet);
}
/// Add a 3D tensor argument.
template <size_t size0, size_t size1, size_t size2>
outcome::checked<void, StringError>
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
KeySet &keySet) {
return pushArg((uint8_t *)arg.data(),
llvm::ArrayRef<int64_t>{size0, size1, size2}, keySet);
}
// Generalize by computing shape by template recursion
/// Set an argument at the given pos as a 1D tensor of T.
template <typename T>
outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
KeySet &keySet) {
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
}
/// Set an argument at the given pos as a tensor of T.
template <typename T>
outcome::checked<void, StringError>
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
return pushArg(static_cast<const T *>(data), shape, keySet);
}
template <typename T>
outcome::checked<void, StringError>
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
auto exporter = getExporter(keySet);
OUTCOME_TRY(auto value, exporter->exportValue(data, shape, values.size()));
values.push_back(std::move(value));
return outcome::success();
}
/// Recursive case for scalars: extract first scalar argument from
/// parameter pack and forward rest
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
OtherArgs... others) {
OUTCOME_TRYV(pushArg(arg0, keySet));
return pushArgs(keySet, others...);
}
/// Recursive case for tensors: extract pointer and size from
/// parameter pack and forward rest
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError>
pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
OUTCOME_TRYV(pushArg(arg0, size, keySet));
return pushArgs(keySet, others...);
}
/// Terminal case of pushArgs
outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
return checkAllArgs(keySet);
}
private:
std::unique_ptr<ValueExporterInterface> getExporter(KeySet &keySet) {
if (isSimulated()) {
return std::make_unique<SimulatedValueExporter>(
keySet.clientParameters());
} else {
return std::make_unique<ValueExporter>(keySet, keySet.clientParameters());
}
}
/// Store buffers of ciphertexts
std::vector<ScalarOrTensorData> values;
/// Whether it simulates an encrypted argument or not
bool simulation;
};
} // namespace clientlib
} // namespace concretelang
#endif


@@ -1,194 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
#include <cassert>
#include <memory>
#include <vector>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
struct Csprng;
struct CsprngVtable;
namespace concretelang {
namespace clientlib {
class CSPRNG {
public:
struct Csprng *ptr;
const struct CsprngVtable *vtable;
CSPRNG() = delete;
CSPRNG(CSPRNG &) = delete;
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
assert(ptr != nullptr);
other.ptr = nullptr;
};
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
};
class ConcreteCSPRNG : public CSPRNG {
public:
ConcreteCSPRNG(__uint128_t seed);
ConcreteCSPRNG() = delete;
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
ConcreteCSPRNG(ConcreteCSPRNG &&other);
~ConcreteCSPRNG();
};
/// @brief LweSecretKey implements tools for manipulating an lwe secret key on
/// the client.
class LweSecretKey {
std::shared_ptr<std::vector<uint64_t>> _buffer;
LweSecretKeyParam _parameters;
public:
LweSecretKey() = delete;
LweSecretKey(LweSecretKeyParam &parameters, CSPRNG &csprng);
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
LweSecretKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
CSPRNG &csprng) const;
/// @brief Decrypt the ciphertext to the plaintext
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
/// @brief Returns the buffer that holds the secret key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the secret key.
LweSecretKeyParam parameters() const { return this->_parameters; }
/// @brief Returns the lwe dimension of the secret key.
size_t dimension() const { return parameters().dimension; }
};
/// @brief LweKeyswitchKey implements tools for manipulating a keyswitch key on
/// the client.
class LweKeyswitchKey {
private:
std::shared_ptr<std::vector<uint64_t>> _buffer;
KeyswitchKeyParam _parameters;
public:
LweKeyswitchKey() = delete;
LweKeyswitchKey(KeyswitchKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
KeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
/// @brief Returns the buffer that holds the keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the keyswitch key.
KeyswitchKeyParam parameters() const { return this->_parameters; }
};
/// @brief LweBootstrapKey implements tools for manipulating a bootstrap key on
/// the client.
class LweBootstrapKey {
private:
std::shared_ptr<std::vector<uint64_t>> _buffer;
BootstrapKeyParam _parameters;
public:
LweBootstrapKey() = delete;
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
BootstrapKeyParam &parameters)
: _buffer(buffer), _parameters(parameters){};
LweBootstrapKey(BootstrapKeyParam &parameters, LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng);
/// @brief Returns the buffer that holds the bootstrap key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the bootstrap key.
BootstrapKeyParam parameters() const { return this->_parameters; }
};
/// @brief PackingKeyswitchKey implements tools for manipulating a private packing
/// keyswitch key on the client.
class PackingKeyswitchKey {
private:
std::shared_ptr<std::vector<uint64_t>> _buffer;
PackingKeyswitchKeyParam _parameters;
public:
PackingKeyswitchKey() = delete;
PackingKeyswitchKey(PackingKeyswitchKeyParam &parameters,
LweSecretKey &inputKey, LweSecretKey &outputKey,
CSPRNG &csprng);
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
PackingKeyswitchKeyParam parameters)
: _buffer(buffer), _parameters(parameters){};
/// @brief Returns the buffer that holds the packing keyswitch key.
const uint64_t *buffer() const { return _buffer->data(); }
size_t size() const { return _buffer->size(); }
/// @brief Returns the parameters of the packing keyswitch key.
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
};
// =============================================
/// Evaluation keys required for execution.
class EvaluationKeys {
private:
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
public:
EvaluationKeys() = delete;
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
const std::vector<LweBootstrapKey> bootstrapKeys,
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
packingKeyswitchKeys(packingKeyswitchKeys) {}
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
return this->keyswitchKeys[id];
}
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
return this->keyswitchKeys;
}
const LweBootstrapKey &getBootstrapKey(size_t id) const {
return bootstrapKeys[id];
}
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
return this->bootstrapKeys;
}
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
return this->packingKeyswitchKeys[id];
};
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
return this->packingKeyswitchKeys;
}
};
// =============================================
} // namespace clientlib
} // namespace concretelang
#endif
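As a hedged usage sketch (assuming a fully generated KeySet named `keySet`, see the KeySet class below): the EvaluationKeys bundle is the public material that travels to the server, and individual keys are addressed by the numeric ids recorded in the client parameters.

concretelang::clientlib::EvaluationKeys evalKeys = keySet.evaluationKeys();
if (!evalKeys.getBootstrapKeys().empty()) {
  // Read back the first bootstrap key and the size of its raw u64 buffer.
  const auto &bsk = evalKeys.getBootstrapKey(0);
  size_t words = bsk.size();
  (void)words;
}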


@@ -1,128 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_KEYSET_H_
#define CONCRETELANG_CLIENTLIB_KEYSET_H_
#include <memory>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/DFRuntime.hpp"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class KeySet {
public:
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
KeySet(KeySet &other) = delete;
/// Generate a KeySet from a ClientParameters specification.
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters clientParameters, CSPRNG &&csprng);
/// Create a KeySet from a set of given keys
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
/// Returns the ClientParameters associated with the KeySet.
ClientParameters clientParameters() const { return _clientParameters; }
// isInputEncrypted returns true if the input at the given pos is encrypted.
bool isInputEncrypted(size_t pos);
/// Allocate an LWE ciphertext buffer for the argument at argPos and set the size
/// of the allocated buffer.
outcome::checked<void, StringError>
allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
/// Encrypt the input into the ciphertext for the argument at argPos.
outcome::checked<void, StringError>
encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
/// isOutputEncrypted returns true if the output at the given pos is encrypted.
bool isOutputEncrypted(size_t pos);
/// Decrypt the ciphertext into the output for the argument at argPos.
outcome::checked<void, StringError>
decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
size_t numInputs() { return inputs.size(); }
size_t numOutputs() { return outputs.size(); }
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
/// @brief evaluationKeys returns the evaluation keys associated with this client
/// keyset. Those evaluation keys can be safely shared publicly.
EvaluationKeys evaluationKeys();
const std::vector<LweSecretKey> &getSecretKeys() const;
const std::vector<LweBootstrapKey> &getBootstrapKeys() const;
const std::vector<LweKeyswitchKey> &getKeyswitchKeys() const;
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys() const;
protected:
outcome::checked<void, StringError>
generateSecretKey(LweSecretKeyParam param);
outcome::checked<void, StringError>
generateBootstrapKey(BootstrapKeyParam param);
outcome::checked<void, StringError>
generateKeyswitchKey(KeyswitchKeyParam param);
outcome::checked<void, StringError>
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
outcome::checked<void, StringError> generateKeysFromParams();
outcome::checked<void, StringError> setupEncryptionMaterial();
friend class KeySetCache;
private:
CSPRNG csprng;
///////////////////////////////////////////////
// Keys mappings
std::vector<LweSecretKey> secretKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
///////////////////////////////////////////////
// Convenient positional mapping between circuit gates and secret keys
typedef std::vector<std::pair<CircuitGate, std::optional<LweSecretKey>>>
SecretKeyGateMapping;
outcome::checked<SecretKeyGateMapping, StringError>
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
SecretKeyGateMapping inputs;
SecretKeyGateMapping outputs;
clientlib::ClientParameters _clientParameters;
};
} // namespace clientlib
} // namespace concretelang
#endif
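A minimal sketch of the intended client-side flow, assuming a ClientParameters instance `params` obtained elsewhere: generate a key set with a fresh CSPRNG, then allocate and encrypt the first input if it is an encrypted gate.

using namespace concretelang::clientlib;
auto keySetOrErr = KeySet::generate(params, ConcreteCSPRNG(/*seed=*/0));
if (keySetOrErr.has_error()) {
  // handle the StringError
}
auto keySet = std::move(keySetOrErr.value());
if (keySet->isInputEncrypted(0)) {
  uint64_t *ciphertext = nullptr;
  uint64_t size = 0;
  // allocate_lwe sizes the buffer from the gate's LWE dimension (and CRT, if any)
  auto alloc = keySet->allocate_lwe(0, &ciphertext, size);
  auto enc = keySet->encrypt_lwe(0, ciphertext, /*input=*/3);
  (void)alloc; (void)enc;
}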


@@ -1,43 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
#define CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
#include "concretelang/ClientLib/KeySet.h"
namespace concretelang {
namespace clientlib {
class KeySet;
class KeySetCache {
std::string backingDirectoryPath;
public:
KeySetCache(std::string backingDirectoryPath)
: backingDirectoryPath(backingDirectoryPath) {}
static outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
private:
static outcome::checked<std::unique_ptr<KeySet>, StringError>
loadKeys(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb,
std::string folderPath);
outcome::checked<std::unique_ptr<KeySet>, StringError>
loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb);
};
} // namespace clientlib
} // namespace concretelang
#endif
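A short sketch of how the cache is meant to be used (the directory path is illustrative and `params` is assumed to be a ClientParameters instance in scope): the static generate overload either loads keys matching the parameters and seed from the backing directory, or generates and saves them.

using namespace concretelang::clientlib;
auto cache = std::make_shared<KeySetCache>("/tmp/concrete-keysets");
auto keySetOrErr =
    KeySetCache::generate(cache, params, /*seed_msb=*/0, /*seed_lsb=*/0);
if (keySetOrErr.has_error()) {
  // handle the StringError
}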


@@ -1,146 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
#define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
#include <iostream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/ClientLib/ValueDecrypter.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace serverlib {
class ServerLambda;
}
} // namespace concretelang
namespace mlir {
namespace concretelang {
class JITLambda;
}
} // namespace mlir
namespace concretelang {
namespace clientlib {
using concretelang::clientlib::ValueDecrypter;
using concretelang::error::StringError;
class EncryptedArguments;
/// PublicArguments will be sent to the server. It includes encrypted
/// arguments and public keys.
class PublicArguments {
public:
PublicArguments(const ClientParameters &clientParameters,
std::vector<clientlib::SharedScalarOrTensorData> &buffers);
~PublicArguments();
static outcome::checked<std::unique_ptr<PublicArguments>, StringError>
unserialize(const ClientParameters &expectedParams, std::istream &istream);
outcome::checked<void, StringError> serialize(std::ostream &ostream);
std::vector<SharedScalarOrTensorData> &getArguments() { return arguments; }
ClientParameters &getClientParameters() { return clientParameters; }
friend class ::concretelang::serverlib::ServerLambda;
friend class ::mlir::concretelang::JITLambda;
private:
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
ClientParameters clientParameters;
/// Store buffers of ciphertexts
std::vector<SharedScalarOrTensorData> arguments;
};
/// PublicResult is a result of a ServerLambda call which contains encrypted
/// results.
struct PublicResult {
PublicResult(const ClientParameters &clientParameters,
std::vector<SharedScalarOrTensorData> &&buffers = {})
: clientParameters(clientParameters), buffers(std::move(buffers)){};
PublicResult(PublicResult &) = delete;
/// @brief Return a value from the PublicResult
/// @param argPos The position of the value in the PublicResult
/// @return Either the value or an error if there is no value at this
/// position
outcome::checked<SharedScalarOrTensorData, StringError>
getValue(size_t argPos) {
if (argPos >= buffers.size()) {
return StringError("result #") << argPos << " does not exists";
}
return buffers[argPos];
}
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<SharedScalarOrTensorData> &&buffers) {
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
}
/// Unserialize from an input stream in place.
outcome::checked<void, StringError> unserialize(std::istream &istream);
/// Unserialize from an input stream returning a new PublicResult.
static outcome::checked<std::unique_ptr<PublicResult>, StringError>
unserialize(ClientParameters &expectedParams, std::istream &istream) {
auto publicResult = std::make_unique<PublicResult>(expectedParams);
OUTCOME_TRYV(publicResult->unserialize(istream));
return std::move(publicResult);
}
/// Serialize into an output stream.
outcome::checked<void, StringError> serialize(std::ostream &ostream);
/// Get the result at `pos` as a scalar. Decryption happens if the
/// result is encrypted.
template <typename T>
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
size_t pos) {
ValueDecrypter decrypter(keySet, clientParameters);
auto &data = buffers[pos].get();
return decrypter.template decrypt<T>(data, pos);
}
/// Get the result at `pos` as a vector. Decryption happens if the
/// result is encrypted.
template <typename T>
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
ValueDecrypter decrypter(keySet, clientParameters);
return decrypter.template decryptTensor<T>(buffers[pos].get(), pos);
}
/// Return the shape of the clear tensor of a result.
outcome::checked<std::vector<int64_t>, StringError>
asClearTextShape(size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
return gate.shape.dimensions;
}
// private: TODO tmp
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;
std::vector<SharedScalarOrTensorData> buffers;
};
/// Helper function to convert from MemRefDescriptor to
/// TensorData
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides);
} // namespace clientlib
} // namespace concretelang
#endif
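A hedged round-trip sketch, assuming `publicArgs` (a PublicArguments built from encrypted arguments), `clientParams` and `keySet` are in scope, that <sstream> is included, and that the server-side evaluation in between is elided: the client serializes its arguments, later reads the PublicResult back, and decrypts output 0.

std::stringstream request(std::ios::in | std::ios::out | std::ios::binary);
std::stringstream response(std::ios::in | std::ios::out | std::ios::binary);
auto ser = publicArgs.serialize(request);
// ... the server consumes `request`, runs the circuit, and writes `response` ...
auto resultOrErr =
    concretelang::clientlib::PublicResult::unserialize(clientParams, response);
if (!resultOrErr.has_error()) {
  // Decrypts if the first output gate is encrypted, otherwise returns it as-is.
  auto clear = resultOrErr.value()->asClearTextScalar<uint64_t>(keySet, 0);
}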


@@ -1,129 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
#include <iostream>
#include <limits>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/Types.h"
namespace concretelang {
namespace clientlib {
// Integers are not serialized as binary values even on a binary stream,
// so we cannot rely on the << operator directly.
template <typename Word>
std::ostream &writeWord(std::ostream &ostream, Word word) {
ostream.write(reinterpret_cast<char *>(&(word)), sizeof(word));
assert(ostream.good());
return ostream;
}
template <typename Size>
std::ostream &writeSize(std::ostream &ostream, Size size) {
return writeWord(ostream, size);
}
// for the sake of symmetry
template <typename Word>
std::istream &readWord(std::istream &istream, Word &word) {
istream.read(reinterpret_cast<char *>(&(word)), sizeof(word));
assert(istream.good());
return istream;
}
template <typename Word>
std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
assert(istream.good());
return istream;
}
template <typename Size>
std::istream &readSize(std::istream &istream, Size &size) {
return readWord(istream, size);
}
template <typename Stream> bool incorrectMode(Stream &stream) {
auto binary = stream.flags() && std::ios::binary;
if (!binary) {
stream.setstate(std::ios::failbit);
}
return !binary;
}
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
outcome::checked<ScalarData, StringError>
unserializeScalarData(std::istream &istream);
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream);
template <typename T>
std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
const llvm::ArrayRef<T> &values,
std::ostream &ostream) {
writeWord<uint64_t>(ostream, dimensions.size());
for (size_t dim : dimensions)
writeWord<int64_t>(ostream, dim);
writeWord<uint64_t>(ostream, sizeof(T) * 8);
writeWord<uint8_t>(ostream, std::is_signed<T>());
for (T val : values)
writeWord(ostream, val);
return ostream;
}
outcome::checked<TensorData, StringError>
unserializeTensorData(std::istream &istream);
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
std::ostream &ostream);
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(std::istream &istream);
std::ostream &serializeVectorOfScalarOrTensorData(
const std::vector<SharedScalarOrTensorData> &sotd, std::ostream &ostream);
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
LweSecretKey readLweSecretKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweKeyswitchKey &wrappedKsk);
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const LweBootstrapKey &wrappedBsk);
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &wrappedKsk);
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const KeySet &keySet);
std::unique_ptr<KeySet> readKeySet(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys);
EvaluationKeys readEvaluationKeys(std::istream &istream);
} // namespace clientlib
} // namespace concretelang
#endif
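The word-level helpers above write raw bytes rather than formatted text, so reads and writes must go through the same binary stream. A tiny sketch (assuming <sstream> and <cassert> are included):

std::stringstream s(std::ios::in | std::ios::out | std::ios::binary);
uint64_t written = 42, read = 0;
concretelang::clientlib::writeWord(s, written);
concretelang::clientlib::readWord(s, read);
assert(read == written); // round-trips as 8 raw bytes, independent of locale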


@@ -1,898 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
#define CONCRETELANG_CLIENTLIB_TYPES_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <stddef.h>
#include <vector>
namespace concretelang {
namespace clientlib {
template <size_t N> struct MemRefDescriptor {
uint64_t *allocated;
uint64_t *aligned;
size_t offset;
size_t sizes[N];
size_t strides[N];
};
using decrypted_scalar_t = std::uint64_t;
using decrypted_tensor_1_t = std::vector<decrypted_scalar_t>;
using decrypted_tensor_2_t = std::vector<decrypted_tensor_1_t>;
using decrypted_tensor_3_t = std::vector<decrypted_tensor_2_t>;
template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
using encrypted_scalar_t = uint64_t *;
using encrypted_scalars_t = uint64_t *;
// Element types for `TensorData`
enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };
// Returns the width in bits of an integer whose width is a power of
// two that can hold values with at most `width` bits
static inline constexpr size_t getStorageWidth(size_t width) {
if (width > 64)
assert(false && "Unsupported scalar width");
if (width > 32) {
return 64;
} else if (width > 16) {
return 32;
} else if (width > 8) {
return 16;
} else {
return 8;
}
}
// Translates `sign` and `width` into an `ElementType`.
static inline ElementType getElementTypeFromWidthAndSign(size_t width,
bool sign) {
switch (getStorageWidth(width)) {
case 64:
return (sign) ? ElementType::i64 : ElementType::u64;
case 32:
return (sign) ? ElementType::i32 : ElementType::u32;
case 16:
return (sign) ? ElementType::i16 : ElementType::u16;
case 8:
return (sign) ? ElementType::i8 : ElementType::u8;
default:
assert(false && "Unsupported scalar width");
}
}
namespace {
// Returns the number of bits for an element type
static constexpr size_t getElementTypeWidth(ElementType t) {
switch (t) {
case ElementType::u64:
case ElementType::i64:
return 64;
case ElementType::u32:
case ElementType::i32:
return 32;
case ElementType::u16:
case ElementType::i16:
return 16;
case ElementType::u8:
case ElementType::i8:
return 8;
}
// Cannot happen
return 0;
}
// Returns `true` if the element type `t` designates a signed type,
// otherwise `false`.
static constexpr size_t getElementTypeSignedness(ElementType t) {
switch (t) {
case ElementType::u64:
case ElementType::u32:
case ElementType::u16:
case ElementType::u8:
return false;
case ElementType::i64:
case ElementType::i32:
case ElementType::i16:
case ElementType::i8:
return true;
}
// Cannot happen
return false;
}
// Returns `true` iff the element type `t` designates the smallest
// unsigned / signed (depending on `sign`) integer type that can hold
// values of up to `width` bits, otherwise false.
static inline bool checkElementTypeForWidthAndSign(ElementType t, size_t width,
bool sign) {
return getElementTypeFromWidthAndSign(getStorageWidth(width), sign) == t;
}
} // namespace
// Constants for the element types used for tensors representing
// encrypted data and data after decryption
constexpr ElementType EncryptedScalarElementType = ElementType::u64;
constexpr size_t EncryptedScalarElementWidth =
getElementTypeWidth(EncryptedScalarElementType);
using EncryptedScalarElement = uint64_t;
namespace detail {
namespace TensorData {
// Union used to store the pointer to the actual data of an instance
// of `TensorData`. Values are stored contiguously in memory in a
// `std::vector` whose element type corresponds to the element type of
// the tensor.
union value_vector_union {
std::vector<uint64_t> *u64;
std::vector<int64_t> *i64;
std::vector<uint32_t> *u32;
std::vector<int32_t> *i32;
std::vector<uint16_t> *u16;
std::vector<int16_t> *i16;
std::vector<uint8_t> *u8;
std::vector<int8_t> *i8;
};
// Function templates that would go into the class `TensorData`, but
// which need to declared in namespace scope, since specializations of
// templates on the return type cannot be done for member functions as
// per the C++ standard
template <typename T> T begin(union value_vector_union &vec);
template <typename T> T end(union value_vector_union &vec);
template <typename T> T cbegin(union value_vector_union &vec);
template <typename T> T cend(union value_vector_union &vec);
template <typename T> T getElements(union value_vector_union &vec);
template <typename T> T getConstElements(const union value_vector_union &vec);
template <typename T>
T getElementValue(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T &getElementReference(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T *getElementPointer(union value_vector_union &vec, size_t idx,
ElementType elementType);
// Specializations for the above templates
#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX) \
template <> \
inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) { \
return vec.SUFFIX->begin(); \
} \
\
template <> \
inline std::vector<ELTY>::iterator end(union value_vector_union &vec) { \
return vec.SUFFIX->end(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cbegin( \
union value_vector_union &vec) { \
return vec.SUFFIX->cbegin(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cend( \
union value_vector_union &vec) { \
return vec.SUFFIX->cend(); \
} \
\
template <> \
inline std::vector<ELTY> &getElements(union value_vector_union &vec) { \
return *vec.SUFFIX; \
} \
\
template <> \
inline const std::vector<ELTY> &getConstElements( \
const union value_vector_union &vec) { \
return *vec.SUFFIX; \
}
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)
#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
template <> \
inline ELTY getElementValue(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return &(*vec.SUFFIX)[idx]; \
}
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
} // namespace TensorData
} // namespace detail
// Representation of a tensor with an arbitrary number of dimensions
class TensorData {
protected:
detail::TensorData::value_vector_union values;
ElementType elementType;
std::vector<size_t> dimensions;
size_t elementWidth;
/* Multi-dimensional, uninitialized, but preallocated tensor */
void initPreallocated(llvm::ArrayRef<size_t> dimensions,
ElementType elementType, size_t elementWidth,
bool sign) {
assert(checkElementTypeForWidthAndSign(elementType, elementWidth, sign) &&
"Incoherent parameters for element type, width and sign");
assert(dimensions.size() != 0);
size_t n = getNumElements(dimensions);
switch (elementType) {
case ElementType::u64:
this->values.u64 = new std::vector<uint64_t>(n);
break;
case ElementType::i64:
this->values.i64 = new std::vector<int64_t>(n);
break;
case ElementType::u32:
this->values.u32 = new std::vector<uint32_t>(n);
break;
case ElementType::i32:
this->values.i32 = new std::vector<int32_t>(n);
break;
case ElementType::u16:
this->values.u16 = new std::vector<uint16_t>(n);
break;
case ElementType::i16:
this->values.i16 = new std::vector<int16_t>(n);
break;
case ElementType::u8:
this->values.u8 = new std::vector<uint8_t>(n);
break;
case ElementType::i8:
this->values.i8 = new std::vector<int8_t>(n);
break;
}
this->dimensions.resize(dimensions.size());
this->elementWidth = elementWidth;
this->elementType = elementType;
std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
}
// Creates a vector<size_t> from an ArrayRef<T>
template <typename T>
static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
return std::vector<size_t>(dims.begin(), dims.end());
}
public:
// Returns the total number of elements of a tensor with the
// specified dimensions
template <typename T> static size_t getNumElements(T dimensions) {
size_t n = 1;
for (auto dim : dimensions)
n *= dim;
return n;
}
// Move constructor. Leaves `that` uninitialized.
TensorData(TensorData &&that)
: elementType(that.elementType), dimensions(std::move(that.dimensions)),
elementWidth(that.elementWidth) {
switch (that.elementType) {
case ElementType::u64:
this->values.u64 = that.values.u64;
that.values.u64 = nullptr;
break;
case ElementType::i64:
this->values.i64 = that.values.i64;
that.values.i64 = nullptr;
break;
case ElementType::u32:
this->values.u32 = that.values.u32;
that.values.u32 = nullptr;
break;
case ElementType::i32:
this->values.i32 = that.values.i32;
that.values.i32 = nullptr;
break;
case ElementType::u16:
this->values.u16 = that.values.u16;
that.values.u16 = nullptr;
break;
case ElementType::i16:
this->values.i16 = that.values.i16;
that.values.i16 = nullptr;
break;
case ElementType::u8:
this->values.u8 = that.values.u8;
that.values.u8 = nullptr;
break;
case ElementType::i8:
this->values.i8 = that.values.i8;
that.values.i8 = nullptr;
break;
}
}
// Constructor to build a multi-dimensional tensor with the
// corresponding element type. All elements are initialized with the
// default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType,
size_t elementWidth) {
initPreallocated(dimensions, elementType, elementWidth,
getElementTypeSignedness(elementType));
}
TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType,
size_t elementWidth)
: TensorData(toDimSpec(dimensions), elementType, elementWidth) {}
// Constructor to build a multi-dimensional tensor with the element
// type corresponding to `elementWidth` and `sign`. All elements are
// initialized with the default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth, bool sign)
: TensorData(dimensions,
getElementTypeFromWidthAndSign(elementWidth, sign),
elementWidth) {}
TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth, bool sign)
: TensorData(toDimSpec(dimensions), elementWidth, sign) {}
#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX) \
/* Multi-dimensional, initialized tensor, values copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions, \
size_t elementWidth) \
: dimensions(dimensions.begin(), dimensions.end()) { \
assert(checkElementTypeForWidthAndSign(ElementType::SUFFIX, elementWidth, \
std::is_signed<ELTY>()) && \
"wrong element type for width"); \
assert(dimensions.size() != 0); \
size_t n = getNumElements(dimensions); \
this->values.SUFFIX = new std::vector<ELTY>(n); \
this->elementType = ElementType::SUFFIX; \
this->bulkAssign(values); \
} \
\
/* One-dimensional, initialized tensor. Values are copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values, size_t width) \
: TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}, \
width) {}
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)
~TensorData() {
switch (this->elementType) {
case ElementType::u64:
delete values.u64;
break;
case ElementType::i64:
delete values.i64;
break;
case ElementType::u32:
delete values.u32;
break;
case ElementType::i32:
delete values.i32;
break;
case ElementType::u16:
delete values.u16;
break;
case ElementType::i16:
delete values.i16;
break;
case ElementType::u8:
delete values.u8;
break;
case ElementType::i8:
delete values.i8;
break;
}
}
// Returns the total number of elements of the tensor
size_t length() const { return getNumElements(this->dimensions); }
// Returns a vector with the size for each dimension of the tensor
const std::vector<size_t> &getDimensions() const { return this->dimensions; }
template <typename T> const std::vector<T> getDimensionsAs() const {
return std::vector<T>(this->dimensions.begin(), this->dimensions.end());
}
// Returns the number of dimensions
size_t getRank() const { return this->dimensions.size(); }
// Multi-dimensional access to a tensor element
template <typename T> T &operator[](llvm::ArrayRef<int64_t> index) {
// Number of dimensions must match
assert(index.size() == dimensions.size());
int64_t offset = 0;
int64_t multiplier = 1;
for (int64_t i = index.size() - 1; i >= 0; i--) {
offset += index[i] * multiplier;
multiplier *= this->dimensions[i];
}
return detail::TensorData::getElementReference<T>(values, offset,
elementType);
}
// Iterator pointing to the first element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator begin() {
return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
}
// Iterator pointing past the last element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator end() {
return detail::TensorData::end<typename std::vector<T>::iterator>(values);
}
// Const iterator pointing to the first element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cbegin() {
return detail::TensorData::cbegin<typename std::vector<T>::iterator>(
values);
}
// Const iterator pointing past the last element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cend() {
return detail::TensorData::cend<typename std::vector<T>::iterator>(values);
}
// Flat representation of the const tensor
template <typename T> const std::vector<T> &getElements() const {
return detail::TensorData::getConstElements<const std::vector<T> &>(values);
}
// Flat representation of the tensor
template <typename T> const std::vector<T> &getElements() {
return detail::TensorData::getElements<std::vector<T> &>(values);
}
// Returns the `index`-th value of a flat representation of the tensor
template <typename T> T getElementValue(size_t index) {
return detail::TensorData::getElementValue<T>(values, index, elementType);
}
// Returns a reference to the `index`-th value of a flat
// representation of the tensor
template <typename T> T &getElementReference(size_t index) {
return detail::TensorData::getElementReference<T>(values, index,
elementType);
}
// Returns a pointer to the `index`-th value of a flat
// representation of the tensor
template <typename T> T *getElementPointer(size_t index) {
return detail::TensorData::getElementPointer<T>(values, index, elementType);
}
// Returns a pointer to the `index`-th value of a flat
// representation of the tensor (const version)
template <typename T> const T *getElementPointer(size_t index) const {
return detail::TensorData::getElementPointer<T>(values, index, elementType);
}
// Returns a void pointer to the `index`-th value of a flat
// representation of the tensor
void *getOpaqueElementPointer(size_t index) {
switch (this->elementType) {
case ElementType::u64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint64_t>(values, index,
elementType));
case ElementType::i64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int64_t>(values, index,
elementType));
case ElementType::u32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint32_t>(values, index,
elementType));
case ElementType::i32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int32_t>(values, index,
elementType));
case ElementType::u16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint16_t>(values, index,
elementType));
case ElementType::i16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int16_t>(values, index,
elementType));
case ElementType::u8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint8_t>(values, index,
elementType));
case ElementType::i8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int8_t>(values, index,
elementType));
}
assert(false && "Unknown element type");
}
// Returns the element type of the tensor
ElementType getElementType() const { return this->elementType; }
// Returns the actual width in bits of a data element (i.e., the
// width specified upon construction and not the storage width of an
// element)
size_t getElementWidth() const { return this->elementWidth; }
// Returns the size of a tensor element in bytes (i.e., the storage width in
// bytes)
size_t getElementSize() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::i64:
return 8;
case ElementType::u32:
case ElementType::i32:
return 4;
case ElementType::u16:
case ElementType::i16:
return 2;
case ElementType::u8:
case ElementType::i8:
return 1;
}
}
// Returns `true` if elements are signed, otherwise `false`
bool getElementSignedness() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::u32:
case ElementType::u16:
case ElementType::u8:
return false;
case ElementType::i64:
case ElementType::i32:
case ElementType::i16:
case ElementType::i8:
return true;
}
}
// Returns the total number of elements of the tensor
size_t getNumElements() const { return getNumElements(this->dimensions); }
// Copy all elements from `values` to the tensor. Note that this
// does not append values to the tensor, but overwrites existing
// values.
template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
assert(values.size() <= this->getNumElements());
switch (this->elementType) {
case ElementType::u64:
std::copy(values.begin(), values.end(), this->values.u64->begin());
break;
case ElementType::i64:
std::copy(values.begin(), values.end(), this->values.i64->begin());
break;
case ElementType::u32:
std::copy(values.begin(), values.end(), this->values.u32->begin());
break;
case ElementType::i32:
std::copy(values.begin(), values.end(), this->values.i32->begin());
break;
case ElementType::u16:
std::copy(values.begin(), values.end(), this->values.u16->begin());
break;
case ElementType::i16:
std::copy(values.begin(), values.end(), this->values.i16->begin());
break;
case ElementType::u8:
std::copy(values.begin(), values.end(), this->values.u8->begin());
break;
case ElementType::i8:
std::copy(values.begin(), values.end(), this->values.i8->begin());
break;
}
}
// Copies all elements of a flat representation of the tensor to the
// positions starting with the iterator `start`.
template <typename IT> void copy(IT start) const {
switch (this->elementType) {
case ElementType::u64:
std::copy(this->values.u64->cbegin(), this->values.u64->cend(), start);
break;
case ElementType::i64:
std::copy(this->values.i64->cbegin(), this->values.i64->cend(), start);
break;
case ElementType::u32:
std::copy(this->values.u32->cbegin(), this->values.u32->cend(), start);
break;
case ElementType::i32:
std::copy(this->values.i32->cbegin(), this->values.i32->cend(), start);
break;
case ElementType::u16:
std::copy(this->values.u16->cbegin(), this->values.u16->cend(), start);
break;
case ElementType::i16:
std::copy(this->values.i16->cbegin(), this->values.i16->cend(), start);
break;
case ElementType::u8:
std::copy(this->values.u8->cbegin(), this->values.u8->cend(), start);
break;
case ElementType::i8:
std::copy(this->values.i8->cbegin(), this->values.i8->cend(), start);
break;
}
}
// Returns a flat representation of the tensor with elements
// converted to the type `T`
template <typename T> std::vector<T> asFlatVector() const {
std::vector<T> ret(getNumElements());
this->copy(ret.begin());
return ret;
}
// Returns a void pointer to the first element of a flat
// representation of the tensor
void *getValuesAsOpaquePointer() const {
switch (this->elementType) {
case ElementType::u64:
return static_cast<void *>(values.u64->data());
case ElementType::i64:
return static_cast<void *>(values.i64->data());
case ElementType::u32:
return static_cast<void *>(values.u32->data());
case ElementType::i32:
return static_cast<void *>(values.i32->data());
case ElementType::u16:
return static_cast<void *>(values.u16->data());
case ElementType::i16:
return static_cast<void *>(values.i16->data());
case ElementType::u8:
return static_cast<void *>(values.u8->data());
case ElementType::i8:
return static_cast<void *>(values.i8->data());
}
assert(false && "Unhandled element type");
}
};
namespace detail {
namespace ScalarData {
// Union representing a single scalar value
union scalar_union {
uint64_t u64;
int64_t i64;
uint32_t u32;
int32_t i32;
uint16_t u16;
int16_t i16;
uint8_t u8;
int8_t i8;
};
// Template + specializations that should be in ScalarData, but which need to be
// in namespace scope
template <typename T> T getValue(const union scalar_union &u, ElementType type);
#define SCALARDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
template <> \
inline ELTY getValue(const union scalar_union &u, ElementType type) { \
assert(type == ElementType::SUFFIX); \
return u.SUFFIX; \
}
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
SCALARDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
SCALARDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
SCALARDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
SCALARDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
} // namespace ScalarData
} // namespace detail
// Class representing a single scalar value
class ScalarData {
public:
ScalarData(const ScalarData &s)
: type(s.type), value(s.value), width(s.width) {}
// Construction with a specific type and an actual width, but with a value
// provided in a generic `uint64_t`
ScalarData(uint64_t value, ElementType type, size_t width)
: type(type), width(width) {
assert(width <= getElementTypeWidth(type));
switch (type) {
case ElementType::u64:
this->value.u64 = value;
break;
case ElementType::i64:
this->value.i64 = value;
break;
case ElementType::u32:
this->value.u32 = value;
break;
case ElementType::i32:
this->value.i32 = value;
break;
case ElementType::u16:
this->value.u16 = value;
break;
case ElementType::i16:
this->value.i16 = value;
break;
case ElementType::u8:
this->value.u8 = value;
break;
case ElementType::i8:
this->value.i8 = value;
break;
}
}
// Construction with a specific type determined by `sign` and
// `width`, but value provided in a generic `uint64_t`
ScalarData(uint64_t value, bool sign, size_t width)
: ScalarData(value, getElementTypeFromWidthAndSign(width, sign), width) {}
#define DEF_SCALAR_DATA_CONSTRUCTOR(ELTY, SUFFIX) \
ScalarData(ELTY value) \
: type(ElementType::SUFFIX), \
width(getElementTypeWidth(ElementType::SUFFIX)) { \
this->value.SUFFIX = value; \
}
// Construction from specific value type
DEF_SCALAR_DATA_CONSTRUCTOR(uint64_t, u64)
DEF_SCALAR_DATA_CONSTRUCTOR(int64_t, i64)
DEF_SCALAR_DATA_CONSTRUCTOR(uint32_t, u32)
DEF_SCALAR_DATA_CONSTRUCTOR(int32_t, i32)
DEF_SCALAR_DATA_CONSTRUCTOR(uint16_t, u16)
DEF_SCALAR_DATA_CONSTRUCTOR(int16_t, i16)
DEF_SCALAR_DATA_CONSTRUCTOR(uint8_t, u8)
DEF_SCALAR_DATA_CONSTRUCTOR(int8_t, i8)
template <typename T> T getValue() const {
return detail::ScalarData::getValue<T>(value, type);
}
// Retrieves the value as a generic `uint64_t`
uint64_t getValueAsU64() const {
size_t width = getElementTypeWidth(type);
if (width == 64)
return value.u64;
uint64_t mask = ((uint64_t)1 << width) - 1;
uint64_t val = value.u64 & mask;
return val;
}
ElementType getType() const { return type; }
size_t getWidth() const { return width; }
protected:
ElementType type;
union detail::ScalarData::scalar_union value;
size_t width;
};
// Variant for TensorData and ScalarData
class ScalarOrTensorData {
protected:
std::unique_ptr<ScalarData> scalar;
std::unique_ptr<TensorData> tensor;
public:
ScalarOrTensorData(const ScalarOrTensorData &td) = delete;
ScalarOrTensorData(ScalarOrTensorData &&td)
: scalar(std::move(td.scalar)), tensor(std::move(td.tensor)) {}
ScalarOrTensorData(TensorData &&td)
: scalar(nullptr), tensor(std::make_unique<TensorData>(std::move(td))) {}
ScalarOrTensorData(const ScalarData &s)
: scalar(std::make_unique<ScalarData>(s)), tensor(nullptr) {}
bool isTensor() const { return tensor != nullptr; }
bool isScalar() const { return scalar != nullptr; }
ScalarData &getScalar() {
assert(scalar != nullptr &&
"Attempt to get a scalar value from variant that is a tensor");
return *scalar;
}
const ScalarData &getScalar() const {
assert(scalar != nullptr &&
"Attempt to get a scalar value from variant that is a tensor");
return *scalar;
}
TensorData &getTensor() {
assert(tensor != nullptr &&
"Attempt to get a tensor value from variant that is a scalar");
return *tensor;
}
const TensorData &getTensor() const {
assert(tensor != nullptr &&
"Attempt to get a tensor value from variant that is a scalar");
return *tensor;
}
};
struct SharedScalarOrTensorData {
std::shared_ptr<ScalarOrTensorData> inner;
SharedScalarOrTensorData(std::shared_ptr<ScalarOrTensorData> inner)
: inner{inner} {}
SharedScalarOrTensorData(ScalarOrTensorData &&inner)
: inner{std::make_shared<ScalarOrTensorData>(std::move(inner))} {}
ScalarOrTensorData &get() const { return *this->inner; }
};
} // namespace clientlib
} // namespace concretelang
#endif
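A small sketch of the TensorData container defined above: a 2x3 tensor of 4-bit unsigned values is stored in the smallest power-of-two storage type (uint8_t here) and can be read back flat, either in its storage type or widened. Only the constructors and accessors shown in the header are used.

using concretelang::clientlib::TensorData;
std::vector<size_t> dims{2, 3};
std::vector<uint8_t> init{1, 2, 3, 4, 5, 6};
TensorData td(llvm::ArrayRef<size_t>(dims), /*elementWidth=*/4, /*sign=*/false);
td.bulkAssign(llvm::ArrayRef<uint8_t>(init));   // overwrites the zero-initialized elements
uint8_t first = td.getElementValue<uint8_t>(0); // 1, in the storage type
auto widened = td.asFlatVector<uint64_t>();     // {1, 2, 3, 4, 5, 6} as u64
(void)first; (void)widened;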


@@ -1,239 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_VALUE_DECRYPTER_H
#define CONCRETELANG_CLIENTLIB_VALUE_DECRYPTER_H
#include <iostream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class ValueDecrypterInterface {
protected:
virtual outcome::checked<uint64_t, StringError>
decryptValue(size_t argPos, uint64_t *ciphertext) = 0;
/// Size of the low-level ciphertext, taking into account the CRT if used
virtual int64_t ciphertextSize(CircuitGate &gate) = 0;
/// Output gate at position `argPos`
virtual outcome::checked<CircuitGate, StringError>
outputGate(size_t argPos) = 0;
/// Whether the value decrypter is simulating encryption
virtual bool isSimulated() = 0;
/// Whether argument at pos `argPos` is encrypted or not
virtual outcome::checked<bool, StringError> isEncrypted(size_t argPos) {
OUTCOME_TRY(auto gate, outputGate(argPos));
return gate.isEncrypted();
}
public:
virtual ~ValueDecrypterInterface() = default;
/// @brief Transforms a FHE value into a clear scalar value
/// @tparam T The type of the clear scalar value
/// @param value The value to decrypt
/// @param pos The position of the argument
/// @return Either the decrypted value or an error if the gate doesn't match
/// the expected result.
template <typename T>
outcome::checked<T, StringError> decrypt(ScalarOrTensorData &value,
size_t pos) {
OUTCOME_TRY(auto encrypted, isEncrypted(pos));
if (!encrypted)
return value.getScalar().getValue<T>();
if (isSimulated()) {
OUTCOME_TRY(auto gate, outputGate(pos));
auto crtVec = gate.encryption->encoding.crt;
if (crtVec.empty()) {
// value is a scalar
auto ciphertext = value.getScalar().getValue<uint64_t>();
OUTCOME_TRY(auto decrypted, decryptValue(pos, &ciphertext));
return (T)decrypted;
} else {
// value is a tensor (crt)
auto &buffer = value.getTensor();
auto ciphertext = buffer.getOpaqueElementPointer(0);
OUTCOME_TRY(
auto decrypted,
decryptValue(pos, reinterpret_cast<uint64_t *>(ciphertext)));
return (T)decrypted;
}
}
auto &buffer = value.getTensor();
auto ciphertext = buffer.getOpaqueElementPointer(0);
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRY(auto decrypted, decryptValue(pos, ciphertextu64));
return (T)decrypted;
}
/// @brief Transforms a FHE value into a vector of clear values
/// @tparam T The type of the clear scalar value
/// @param value The value to decrypt
/// @param pos The position of the argument
/// @return Either the decrypted value or an error if the gate doesn't match
/// the expected result.
template <typename T>
outcome::checked<std::vector<T>, StringError>
decryptTensor(ScalarOrTensorData &value, size_t pos) {
OUTCOME_TRY(auto encrypted, isEncrypted(pos));
if (!encrypted)
return value.getTensor().asFlatVector<T>();
auto &buffer = value.getTensor();
OUTCOME_TRY(auto gate, outputGate(pos));
auto lweSize = ciphertextSize(gate);
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRY(auto decrypted, decryptValue(pos, ciphertextu64));
decryptedValues[i] = decrypted;
}
return decryptedValues;
}
/// Return the shape of the clear tensor of a result.
outcome::checked<std::vector<int64_t>, StringError> getShape(size_t pos) {
OUTCOME_TRY(auto gate, outputGate(pos));
return gate.shape.dimensions;
}
};
/// @brief Allows transforming a serializable value into a clear value
class ValueDecrypter : public ValueDecrypterInterface {
public:
ValueDecrypter(KeySet &keySet, ClientParameters clientParameters)
: _keySet(keySet), _clientParameters(clientParameters) {}
protected:
outcome::checked<uint64_t, StringError>
decryptValue(size_t argPos, uint64_t *ciphertext) override {
uint64_t decrypted;
OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertext, decrypted));
return decrypted;
};
bool isSimulated() override { return false; }
outcome::checked<CircuitGate, StringError>
outputGate(size_t argPos) override {
return _clientParameters.ouput(argPos);
}
int64_t ciphertextSize(CircuitGate &gate) override {
return _clientParameters.lweBufferSize(gate);
}
private:
KeySet &_keySet;
ClientParameters _clientParameters;
};
class SimulatedValueDecrypter : public ValueDecrypterInterface {
public:
SimulatedValueDecrypter(ClientParameters clientParameters)
: _clientParameters(clientParameters) {}
protected:
// TODO: a lot of this logic can be factorized when moving
// `KeySet::decrypt_lwe` into the LWE ValueDecrypter
outcome::checked<uint64_t, StringError>
decryptValue(size_t argPos, uint64_t *ciphertext) override {
uint64_t output;
OUTCOME_TRY(auto gate, outputGate(argPos));
auto encoding = gate.encryption->encoding;
auto precision = encoding.precision;
auto crtVec = gate.encryption->encoding.crt;
if (crtVec.empty()) {
output = *ciphertext;
output >>= (64 - precision - 2);
auto carry = output % 2;
uint64_t mod = (((uint64_t)1) << (precision + 1));
output = ((output >> 1) + carry) % mod;
// Further decode signed integers.
if (encoding.isSigned) {
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
if (output >= maxPos) { // The output is actually negative.
// Set the preceding bits to zero
output |= UINT64_MAX << precision;
// This makes sure when the value is cast to int64, it has the correct
// value
};
}
} else {
// Decrypt and decode remainders
std::vector<int64_t> remainders;
for (auto modulus : crtVec) {
output = *ciphertext;
auto plaintext = crt::decode(output, modulus);
remainders.push_back(plaintext);
// each ciphertext is a scalar
ciphertext = ciphertext + 1;
}
output = crt::iCrt(crtVec, remainders);
// Further decode signed integers
if (encoding.isSigned) {
uint64_t maxPos = 1;
for (auto prime : crtVec) {
maxPos *= prime;
}
maxPos /= 2;
if (output >= maxPos) {
output -= maxPos * 2;
}
}
}
return output;
}
bool isSimulated() override { return true; }
outcome::checked<CircuitGate, StringError>
outputGate(size_t argPos) override {
return _clientParameters.ouput(argPos);
}
/// @brief Ciphertext size in simulation
/// When using CRT encoding, it's the number of blocks, otherwise, it's just 1
/// scalar
/// @param gate
/// @return number of scalars to represent one input
int64_t ciphertextSize(CircuitGate &gate) override {
// ciphertexts in simulation are only scalars
assert(gate.encryption.has_value());
auto crtSize = gate.encryption->encoding.crt.size();
return crtSize == 0 ? 1 : crtSize;
}
private:
ClientParameters _clientParameters;
};
} // namespace clientlib
} // namespace concretelang
#endif
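To make the non-CRT decode step of SimulatedValueDecrypter above easier to follow, here is the same arithmetic as a standalone, noiseless worked example for precision = 3, using the encoding applied by the simulated exporter (message shifted into the top bits, one padding bit); <cassert> is assumed.

uint64_t precision = 3;
uint64_t ct = UINT64_C(5) << (64 - (precision + 1)); // simulated "ciphertext" encoding 5
uint64_t v = ct >> (64 - precision - 2);             // keep precision + 2 top bits
uint64_t carry = v % 2;                              // rounding bit absorbs noise
uint64_t mod = UINT64_C(1) << (precision + 1);
uint64_t message = ((v >> 1) + carry) % mod;         // == 5
assert(message == 5);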


@@ -1,283 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_VALUE_EXPORTER_H
#define CONCRETELANG_CLIENTLIB_VALUE_EXPORTER_H
#include <ostream>
#include "boost/outcome.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/simulation.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
class ValueExporterInterface {
protected:
virtual outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
size_t argPos,
uint64_t *ciphertext,
uint64_t input) = 0;
/// Encrypt and export a 64-bit integer to a serializable value
virtual outcome::checked<ScalarOrTensorData, StringError>
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) = 0;
/// Shape of the low-level buffer
virtual std::vector<int64_t> bufferShape(CircuitGate &gate) = 0;
/// Size of the low-level ciphertext, taking into account the CRT if used
virtual int64_t ciphertextSize(CircuitGate &gate) = 0;
/// Input gate at position `argPos`
virtual outcome::checked<CircuitGate, StringError>
inputGate(size_t argPos) = 0;
public:
virtual ~ValueExporterInterface() = default;
/// @brief Export a scalar 64-bit integer to a concreteprotocol::Value
/// @param arg A 64-bit integer
/// @param argPos The position of the argument to export
/// @return Either the exported value ready to be sent to the server or an
/// error if the gate doesn't match the expected argument.
outcome::checked<ScalarOrTensorData, StringError> exportValue(uint64_t arg,
size_t argPos) {
OUTCOME_TRY(auto gate, inputGate(argPos));
if (gate.shape.size != 0) {
return StringError("argument #") << argPos << " is not a scalar";
}
if (gate.encryption.has_value()) {
return exportEncryptValue(arg, gate, argPos);
}
return exportClearValue(arg);
}
/// @brief Export a tensor-like buffer of values to a serializable value
/// @tparam T The type of values hold by the buffer
/// @param arg A pointer to a memory area where the values are stored
/// @param shape The shape of the tensor
/// @param argPos The position of the argument to export
/// @return Either the exported value ready to be sent to the server or an
/// error if the gate doesn't match the expected argument.
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportValue(const T *arg, llvm::ArrayRef<int64_t> shape, size_t argPos) {
OUTCOME_TRY(auto gate, inputGate(argPos));
OUTCOME_TRYV(checkShape(shape, gate.shape, argPos));
if (gate.encryption.has_value()) {
return exportEncryptTensor(arg, shape, gate, argPos);
}
return exportClearTensor(arg, shape, gate);
}
protected:
/// Export a 64-bit integer to a serializable value
virtual outcome::checked<ScalarOrTensorData, StringError>
exportClearValue(uint64_t arg) {
return ScalarData(arg);
}
/// Export a tensor-like buffer to a serializable value
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportClearTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
CircuitGate &gate) {
auto bitsPerValue = bitWidthAsWord(gate.shape.width);
auto sizes = bufferShape(gate);
TensorData td(sizes, bitsPerValue, gate.shape.sign);
llvm::ArrayRef<T> values(arg, TensorData::getNumElements(sizes));
td.bulkAssign(values);
return std::move(td);
}
/// Export and encrypt a tensor-like buffer to a serializable value
template <typename T>
outcome::checked<ScalarOrTensorData, StringError>
exportEncryptTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
CircuitGate &gate, size_t argPos) {
// Create and allocate the TensorData that will hold the encrypted values
auto sizes = bufferShape(gate);
TensorData td(sizes, EncryptedScalarElementType,
EncryptedScalarElementWidth);
// Iterate over the values and encrypt each one at its offset in the buffer
auto lweSize = ciphertextSize(gate);
for (size_t i = 0, offset = 0; i < gate.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(encryptValue(
gate, argPos, td.getElementPointer<uint64_t>(offset), arg[i]));
}
return std::move(td);
}
private:
static outcome::checked<void, StringError>
checkShape(llvm::ArrayRef<int64_t> shape, CircuitGateShape expected,
size_t argPos) {
// Check the shape of tensor
if (expected.dimensions.empty()) {
return StringError("argument #") << argPos << "is not a tensor";
}
if (shape.size() != expected.dimensions.size()) {
return StringError("argument #")
<< argPos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << expected.dimensions.size();
}
// Check shape
for (size_t i = 0; i < shape.size(); i++) {
if (shape[i] != expected.dimensions[i]) {
return StringError("argument #")
<< argPos << " has not the expected dimension #" << i
<< " , got " << shape[i] << " expected "
<< expected.dimensions[i];
}
}
return outcome::success();
}
};
/// @brief The ValueExporter transforms clear arguments into the values
/// expected by a server lambda.
class ValueExporter : public ValueExporterInterface {
public:
/// @brief Constructs a ValueExporter from a key set and client parameters.
/// @param keySet The key set used to encrypt the encrypted arguments
/// @param clientParameters The client parameters describing the circuit gates
// TODO: Keeping a reference here could cause trouble (see the KeySet copy
// constructor, or use shared pointers)
ValueExporter(KeySet &keySet, ClientParameters clientParameters)
: _keySet(keySet), _clientParameters(clientParameters) {}
protected:
outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
size_t argPos,
uint64_t *ciphertext,
uint64_t input) override {
return _keySet.encrypt_lwe(argPos, ciphertext, input);
}
outcome::checked<CircuitGate, StringError> inputGate(size_t argPos) override {
return _clientParameters.input(argPos);
}
std::vector<int64_t> bufferShape(CircuitGate &gate) override {
return _clientParameters.bufferShape(gate);
}
int64_t ciphertextSize(CircuitGate &gate) override {
return _clientParameters.lweBufferSize(gate);
}
/// Encrypt and export a 64-bit integer to a serializable value
outcome::checked<ScalarOrTensorData, StringError>
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) override {
std::vector<int64_t> shape = _clientParameters.bufferShape(gate);
// Create and allocate the TensorData that will hold the encrypted value
TensorData td(shape, clientlib::EncryptedScalarElementType,
clientlib::EncryptedScalarElementWidth);
// Encrypt the value
OUTCOME_TRYV(
encryptValue(gate, argPos, td.getElementPointer<uint64_t>(0), arg));
return std::move(td);
}
private:
KeySet &_keySet;
ClientParameters _clientParameters;
};
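// A minimal usage sketch (illustrative only; `keySet` and `params` are assumed
// to come from a prior key generation and compilation step, they are not
// defined in this header):
//
//   ValueExporter exporter(keySet, params);
//   // Export argument #0 as a (possibly encrypted) scalar.
//   auto scalar = exporter.exportValue(5, 0);
//   // Export argument #1 as a 2x2 tensor.
//   uint8_t buffer[4] = {1, 2, 3, 4};
//   auto tensor = exporter.exportValue(buffer, {2, 2}, 1);
//   // Both calls return either a ScalarOrTensorData or a StringError.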
/// @brief The SimulatedValueExporter transforms clear arguments into the
/// values expected by a server lambda during simulation.
class SimulatedValueExporter : public ValueExporterInterface {
public:
SimulatedValueExporter(ClientParameters clientParameters)
: _clientParameters(clientParameters), csprng(0) {}
protected:
outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
size_t argPos,
uint64_t *ciphertext,
uint64_t input) override {
auto crtVec = gate.encryption->encoding.crt;
if (crtVec.empty()) {
auto precision = gate.encryption->encoding.precision;
OUTCOME_TRY(auto skParam, _clientParameters.lweSecretKeyParam(gate));
auto lwe_dim = skParam.lweDimension();
auto encoded_input = input << (64 - (precision + 1));
*ciphertext =
sim_encrypt_lwe_u64(encoded_input, lwe_dim, (void *)csprng.ptr);
} else {
// Put each decomposition into a new ciphertext
auto product = concretelang::clientlib::crt::productOfModuli(crtVec);
for (auto modulus : crtVec) {
OUTCOME_TRY(auto skParam, _clientParameters.lweSecretKeyParam(gate));
auto lwe_dim = skParam.lweDimension();
auto plaintext = crt::encode(input, modulus, product);
*ciphertext =
sim_encrypt_lwe_u64(plaintext, lwe_dim, (void *)csprng.ptr);
// each ciphertext is a scalar
ciphertext = ciphertext + 1;
}
}
return outcome::success();
}
/// Simulate encryption and export a 64-bit integer to a serializable value
outcome::checked<ScalarOrTensorData, StringError>
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) override {
auto crtVec = gate.encryption->encoding.crt;
if (crtVec.empty()) {
uint64_t encValue = 0;
OUTCOME_TRYV(encryptValue(gate, argPos, &encValue, arg));
return ScalarData(encValue);
} else {
TensorData td(bufferShape(gate), clientlib::EncryptedScalarElementType,
clientlib::EncryptedScalarElementWidth);
OUTCOME_TRYV(
encryptValue(gate, argPos, td.getElementPointer<uint64_t>(0), arg));
return std::move(td);
}
}
outcome::checked<CircuitGate, StringError> inputGate(size_t argPos) override {
return _clientParameters.input(argPos);
}
std::vector<int64_t> bufferShape(CircuitGate &gate) override {
return _clientParameters.bufferShape(gate, true);
}
/// @brief Ciphertext size in simulation
/// When using CRT encoding, it's the number of blocks, otherwise, it's just 1
/// scalar
/// @param gate
/// @return number of scalars to represent one input
int64_t ciphertextSize(CircuitGate &gate) override {
// ciphertexts in simulation are plain scalars
assert(gate.encryption.has_value());
auto crtSize = gate.encryption->encoding.crt.size();
return crtSize == 0 ? 1 : crtSize;
}
private:
ClientParameters _clientParameters;
concretelang::clientlib::ConcreteCSPRNG csprng;
};
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -1,19 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_BITS_SIZE_H
#define CONCRETELANG_COMMON_BITS_SIZE_H
#include <stdlib.h>
namespace concretelang {
namespace common {
size_t bitWidthAsWord(size_t exactBitWidth);
}
} // namespace concretelang
#endif

View File

@@ -3,14 +3,13 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_CLIENTLIB_CRT_H_
#define CONCRETELANG_CLIENTLIB_CRT_H_
#ifndef CONCRETELANG_COMMON_CRT_H_
#define CONCRETELANG_COMMON_CRT_H_
#include <cstdint>
#include <vector>
namespace concretelang {
namespace clientlib {
namespace crt {
/// Compute the product of the moduli of the crt decomposition.
@@ -40,7 +39,6 @@ uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product);
uint64_t decode(uint64_t val, uint64_t modulus);
} // namespace crt
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,514 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
//
// NOTE:
// -----
// To limit the size of the refactoring, we chose not to propagate the new
// client/server lib to the exterior APIs once it was finalized. This file only
// serves as a compatibility layer for the exterior (python/rust/c) APIs, for
// the time being.
#ifndef CONCRETELANG_COMMON_COMPAT
#define CONCRETELANG_COMMON_COMPAT
#include "capnp/serialize-packed.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/ClientLib/ClientLib.h"
#include "concretelang/Common/Keys.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Values.h"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "kj/io.h"
#include "kj/std/iostream.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include <memory>
#include <stdexcept>
using concretelang::clientlib::ClientCircuit;
using concretelang::clientlib::ClientProgram;
using concretelang::keysets::Keyset;
using concretelang::keysets::KeysetCache;
using concretelang::keysets::ServerKeyset;
using concretelang::serverlib::ServerCircuit;
using concretelang::serverlib::ServerProgram;
using concretelang::values::TransportValue;
using concretelang::values::Value;
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
}
#define CONCAT(a, b) CONCAT_INNER(a, b)
#define CONCAT_INNER(a, b) a##b
#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \
auto MAYBE = RESULT; \
if (MAYBE.has_failure()) { \
throw std::runtime_error(MAYBE.as_failure().error().mesg); \
} \
VARNAME = MAYBE.value();
#define GET_OR_THROW_RESULT(VARNAME, RESULT) \
GET_OR_THROW_RESULT_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__))
#define EXPECTED_TRY_(lhs, rhs, maybe) \
auto maybe = rhs; \
if (auto err = maybe.takeError()) { \
return std::move(err); \
} \
lhs = *maybe;
#define EXPECTED_TRY(lhs, rhs) \
EXPECTED_TRY_(lhs, rhs, CONCAT(maybe, __COUNTER__))
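// Minimal usage sketches for the macros above (illustrative only; `someResult`
// and `someExpected` stand for arbitrary functions returning a Result<T> and
// an llvm::Expected<T>, they are not defined in this header):
//
//   // Unwrap a Result<T>, throwing std::runtime_error on failure:
//   uint64_t value;
//   GET_OR_THROW_RESULT(value, someResult());
//
//   // Unwrap an llvm::Expected<T> inside a function that itself returns an
//   // llvm::Expected, forwarding the error on failure:
//   EXPECTED_TRY(auto circuit, someExpected());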
template <typename T> llvm::Expected<T> outcomeToExpected(Result<T> outcome) {
if (outcome.has_failure()) {
return mlir::concretelang::StreamStringError(
outcome.as_failure().error().mesg);
} else {
return outcome.value();
}
}
// Every number sent by python through the API has type `int64` and must be
// turned into the proper type expected by the ArgTransformers. This function
// returns an extra transformer that is executed right before the
// ArgTransformer gets called.
std::function<Value(Value)>
getPythonTypeTransformer(const Message<concreteprotocol::GateInfo> &info) {
if (info.asReader().getTypeInfo().hasIndex()) {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint64_t>)tensorInput};
};
} else if (info.asReader().getTypeInfo().hasPlaintext()) {
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
8) {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint8_t>)tensorInput};
};
}
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
16) {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint16_t>)tensorInput};
};
}
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
32) {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint32_t>)tensorInput};
};
}
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
64) {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint64_t>)tensorInput};
};
}
assert(false);
} else if (info.asReader().getTypeInfo().hasLweCiphertext()) {
if (info.asReader()
.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.hasInteger() &&
info.asReader()
.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.getInteger()
.getIsSigned()) {
return [=](Value input) { return input; };
} else {
return [=](Value input) {
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
return Value{(Tensor<uint64_t>)tensorInput};
};
}
} else {
assert(false);
}
};
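// As an illustration, for a plaintext gate of precision <= 8 the returned
// transformer casts the signed 64-bit tensor received from python into the
// narrower unsigned tensor expected downstream (sketch only; `gateInfo` is
// assumed to come from the program info):
//
//   auto transformer = getPythonTypeTransformer(gateInfo);
//   Value input{Tensor<int64_t>({1, 2, 3, 4}, {4})};
//   Value casted = transformer(input); // now holds a Tensor<uint8_t>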
namespace concretelang {
namespace serverlib {
/// A transition structure that preserves the current API of the library
/// support.
struct ServerLambda {
ServerCircuit circuit;
bool isSimulation;
};
} // namespace serverlib
namespace clientlib {
/// A transition structure that preserves the current API of the library
/// support.
struct LweSecretKeyParam {
Message<concreteprotocol::LweSecretKeyInfo> info;
};
/// A transition structure that preserves the current API of the library
/// support.
struct BootstrapKeyParam {
Message<concreteprotocol::LweBootstrapKeyInfo> info;
};
/// A transition structure that preserves the current API of the library
/// support.
struct KeyswitchKeyParam {
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
};
/// A transition structure that preserves the current API of the library
/// support.
struct PackingKeyswitchKeyParam {
Message<concreteprotocol::PackingKeyswitchKeyInfo> info;
};
/// A transition structure that preserves the current API of the library
/// support.
struct Encoding {
Message<concreteprotocol::EncodingInfo> circuit;
};
/// A transition structure that preserves the current API of the library
/// support.
struct EncryptionGate {
Message<concreteprotocol::GateInfo> gateInfo;
};
/// A transition structure that preserves the current API of the library
/// support.
struct CircuitGate {
Message<concreteprotocol::GateInfo> gateInfo;
};
/// A transition structure that preserves the current API of the library
/// support.
struct ValueExporter {
ClientCircuit circuit;
};
/// A transition structure that preserves the current API of the library
/// support.
struct SimulatedValueExporter {
ClientCircuit circuit;
};
/// A transition structure that preserves the current API of the library
/// support.
struct ValueDecrypter {
ClientCircuit circuit;
};
/// A transition structure that preserves the current API of the library
/// support.
struct SimulatedValueDecrypter {
ClientCircuit circuit;
};
/// A transition structure that preserves the current API of the library
/// support.
struct PublicArguments {
std::vector<TransportValue> values;
};
/// A transition structure that preserves the current API of the library
/// support.
struct PublicResult {
std::vector<TransportValue> values;
};
/// A transition structure that preserves the current API of the library
/// support.
struct SharedScalarOrTensorData {
TransportValue value;
};
/// A transition structure that preserves the current API of the library
/// support.
struct ClientParameters {
Message<concreteprotocol::ProgramInfo> programInfo;
std::vector<LweSecretKeyParam> secretKeys;
std::vector<BootstrapKeyParam> bootstrapKeys;
std::vector<KeyswitchKeyParam> keyswitchKeys;
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
};
/// A transition structure that preserves the current API of the library
/// support.
struct EvaluationKeys {
ServerKeyset keyset;
};
/// A transition structure that preserves the current API of the library
/// support.
struct KeySetCache {
KeysetCache keysetCache;
};
/// A transition structure that preserves the current API of the library
/// support.
struct KeySet {
Keyset keyset;
};
} // namespace clientlib
} // namespace concretelang
namespace mlir {
namespace concretelang {
/// A transition structure that preserves the current API of the library
/// support.
struct LambdaArgument {
::concretelang::values::Value value;
};
/// LibraryCompilationResult is the result of a compilation to a library.
struct LibraryCompilationResult {
/// The output directory path where the compilation artifacts have been
/// generated.
std::string outputDirPath;
std::string funcName;
};
class LibrarySupport {
public:
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
bool generateSharedLib = true, bool generateStaticLib = true,
bool generateClientParameters = true,
bool generateCompilationFeedback = true)
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
generateSharedLib(generateSharedLib),
generateStaticLib(generateStaticLib),
generateClientParameters(generateClientParameters),
generateCompilationFeedback(generateCompilationFeedback) {}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) {
// Setup the compiler engine
auto context = CompilationContext::createShared();
concretelang::CompilerEngine engine(context);
engine.setCompilationOptions(options);
// Compile to a library
auto library =
engine.compile(program, outputPath, runtimeLibraryPath,
generateSharedLib, generateStaticLib,
generateClientParameters, generateCompilationFeedback);
if (auto err = library.takeError()) {
return std::move(err);
}
if (!options.mainFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}
this->funcName = options.mainFuncName.value();
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.mainFuncName;
return std::move(result);
}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(llvm::StringRef s, CompilationOptions options) {
std::unique_ptr<llvm::MemoryBuffer> mb =
llvm::MemoryBuffer::getMemBuffer(s);
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(mb), llvm::SMLoc());
return this->compile(sm, options);
}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(mlir::ModuleOp &program,
std::shared_ptr<mlir::concretelang::CompilationContext> &context,
CompilationOptions options) {
// Setup the compiler engine
concretelang::CompilerEngine engine(context);
engine.setCompilationOptions(options);
// Compile to a library
auto library =
engine.compile(program, outputPath, runtimeLibraryPath,
generateSharedLib, generateStaticLib,
generateClientParameters, generateCompilationFeedback);
if (auto err = library.takeError()) {
return std::move(err);
}
if (!options.mainFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}
this->funcName = options.mainFuncName.value();
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.mainFuncName;
return std::move(result);
}
/// Load the server lambda from the compilation result.
llvm::Expected<::concretelang::serverlib::ServerLambda>
loadServerLambda(LibraryCompilationResult &result, bool useSimulation) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
EXPECTED_TRY(ServerProgram serverProgram,
outcomeToExpected(ServerProgram::load(programInfo.asReader(),
getSharedLibPath(),
useSimulation)));
EXPECTED_TRY(
ServerCircuit serverCircuit,
outcomeToExpected(serverProgram.getServerCircuit(result.funcName)));
return ::concretelang::serverlib::ServerLambda{serverCircuit,
useSimulation};
}
/// Load the client parameters from the compilation result.
llvm::Expected<::concretelang::clientlib::ClientParameters>
loadClientParameters(LibraryCompilationResult &result) {
EXPECTED_TRY(auto programInfo, getProgramInfo());
if (programInfo.asReader().getCircuits().size() > 1) {
return StreamStringError("ClientLambda: Provided program info contains "
"more than one circuit.");
}
if (programInfo.asReader().getCircuits()[0].getName() != result.funcName) {
return StreamStringError("Unexpected circuit name in program info");
}
auto secretKeys =
std::vector<::concretelang::clientlib::LweSecretKeyParam>();
for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) {
secretKeys.push_back(::concretelang::clientlib::LweSecretKeyParam{key});
}
auto bootstrapKeys =
std::vector<::concretelang::clientlib::BootstrapKeyParam>();
for (auto key : programInfo.asReader().getKeyset().getLweBootstrapKeys()) {
bootstrapKeys.push_back(::concretelang::clientlib::BootstrapKeyParam{key});
}
auto keyswitchKeys =
std::vector<::concretelang::clientlib::KeyswitchKeyParam>();
for (auto key : programInfo.asReader().getKeyset().getLweKeyswitchKeys()) {
keyswitchKeys.push_back(
::concretelang::clientlib::KeyswitchKeyParam{key});
}
auto packingKeyswitchKeys =
std::vector<::concretelang::clientlib::PackingKeyswitchKeyParam>();
for (auto key :
programInfo.asReader().getKeyset().getPackingKeyswitchKeys()) {
packingKeyswitchKeys.push_back(
::concretelang::clientlib::PackingKeyswitchKeyParam{key});
}
return ::concretelang::clientlib::ClientParameters{
programInfo, secretKeys, bootstrapKeys, keyswitchKeys,
packingKeyswitchKeys};
}
llvm::Expected<Message<concreteprotocol::ProgramInfo>> getProgramInfo() {
auto path = CompilerEngine::Library::getProgramInfoPath(outputPath);
std::ifstream file(path);
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
if (file.fail()) {
return StreamStringError("Cannot read file: ") << path;
}
auto output = Message<concreteprotocol::ProgramInfo>();
if (output.readJsonFromString(content).has_failure()) {
return StreamStringError("Cannot read json string.");
}
return output;
}
/// Load the compilation result if the circuit is already compiled
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
loadCompilationResult() {
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = funcName;
return std::move(result);
}
llvm::Expected<CompilationFeedback>
loadCompilationFeedback(LibraryCompilationResult &result) {
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
result.outputDirPath);
auto feedback = CompilationFeedback::load(path);
if (feedback.has_error()) {
return StreamStringError(feedback.error().mesg);
}
return feedback.value();
}
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<::concretelang::clientlib::PublicResult>>
serverCall(::concretelang::serverlib::ServerLambda lambda,
::concretelang::clientlib::PublicArguments &args,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
if (lambda.isSimulation) {
return mlir::concretelang::StreamStringError(
"Tried to perform server call on simulation lambda.");
}
EXPECTED_TRY(auto output, outcomeToExpected(lambda.circuit.call(
evaluationKeys.keyset, args.values)));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
}
/// Simulate a call to the lambda with the public arguments.
llvm::Expected<std::unique_ptr<::concretelang::clientlib::PublicResult>>
simulate(::concretelang::serverlib::ServerLambda lambda,
::concretelang::clientlib::PublicArguments &args) {
if (!lambda.isSimulation) {
return mlir::concretelang::StreamStringError(
"Tried to perform simulation on execution lambda.");
}
EXPECTED_TRY(auto output,
outcomeToExpected(lambda.circuit.simulate(args.values)));
::concretelang::clientlib::PublicResult res{output};
return std::make_unique<::concretelang::clientlib::PublicResult>(
std::move(res));
}
/// Get path to shared library
std::string getSharedLibPath() {
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
}
/// Get path to the program info file
std::string getProgramInfoPath() {
return CompilerEngine::Library::getProgramInfoPath(outputPath);
}
private:
std::string outputPath;
std::string funcName;
std::string runtimeLibraryPath;
/// Flags to select generated artifacts
bool generateSharedLib;
bool generateStaticLib;
bool generateClientParameters;
bool generateCompilationFeedback;
};
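// A minimal end-to-end sketch of this compatibility layer (illustrative only;
// `mlirSource`, `options`, `publicArgs` and `evalKeys` are assumed to be
// produced elsewhere, and the output path is an arbitrary example):
//
//   LibrarySupport support("/tmp/artifacts");
//   GET_OR_THROW_LLVM_EXPECTED(compiled, support.compile(mlirSource, options));
//   GET_OR_THROW_LLVM_EXPECTED(lambda,
//                              support.loadServerLambda(**compiled, false));
//   GET_OR_THROW_LLVM_EXPECTED(result,
//                              support.serverCall(*lambda, publicArgs, evalKeys));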
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,46 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_CSPRNG_H
#define CONCRETELANG_COMMON_CSPRNG_H
#include <assert.h>
#include <stdlib.h>
struct Csprng;
struct CsprngVtable;
namespace concretelang {
namespace csprng {
class CSPRNG {
public:
struct Csprng *ptr;
const struct CsprngVtable *vtable;
CSPRNG() = delete;
CSPRNG(CSPRNG &) = delete;
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
assert(ptr != nullptr);
other.ptr = nullptr;
};
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
};
class ConcreteCSPRNG : public CSPRNG {
public:
ConcreteCSPRNG(__uint128_t seed);
ConcreteCSPRNG() = delete;
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
ConcreteCSPRNG(ConcreteCSPRNG &&other);
~ConcreteCSPRNG();
};
} // namespace csprng
} // namespace concretelang
#endif

View File

@@ -5,11 +5,13 @@
#ifndef CONCRETELANG_COMMON_ERROR_H
#define CONCRETELANG_COMMON_ERROR_H
#include "boost/outcome.h"
#include <string>
namespace concretelang {
namespace error {
/// The type of error used throughout the client/server libs.
class StringError {
public:
StringError(std::string mesg) : mesg(mesg){};
@@ -37,6 +39,9 @@ public:
}
};
/// A result type used throughout the client/server libs.
template <typename T> using Result = outcome::checked<T, StringError>;
} // namespace error
} // namespace concretelang

View File

@@ -0,0 +1,169 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_KEYS_H
#define CONCRETELANG_COMMON_KEYS_H
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Protocol.h"
#include <memory>
#include <stdlib.h>
#include <vector>
using concretelang::csprng::CSPRNG;
using concretelang::protocol::Message;
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
inline void getApproval() {
std::cerr << "DANGER: You are using an empty unsecure secret keys. Enter "
"\"y\" to continue: ";
char answer;
std::cin >> answer;
if (answer != 'y') {
std::abort();
}
}
#endif
namespace concretelang {
namespace keys {
/// An object representing an LWE secret key
class LweSecretKey {
friend class Keyset;
friend class KeysetCache;
friend class LweBootstrapKey;
friend class LweKeyswitchKey;
friend class PackingKeyswitchKey;
public:
LweSecretKey(Message<concreteprotocol::LweSecretKeyInfo> info,
CSPRNG &csprng);
LweSecretKey() = delete;
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::LweSecretKeyInfo> info)
: buffer(buffer), info(info){};
static LweSecretKey
fromProto(const Message<concreteprotocol::LweSecretKey> &proto);
Message<concreteprotocol::LweSecretKey> toProto() const;
const uint64_t *getRawPtr() const;
size_t getSize() const;
const Message<concreteprotocol::LweSecretKeyInfo> &getInfo() const;
const std::vector<uint64_t> &getBuffer() const;
typedef Message<concreteprotocol::LweSecretKeyInfo> InfoType;
private:
std::shared_ptr<std::vector<uint64_t>> buffer;
Message<concreteprotocol::LweSecretKeyInfo> info;
};
class LweBootstrapKey {
friend class Keyset;
public:
LweBootstrapKey(Message<concreteprotocol::LweBootstrapKeyInfo> info,
const LweSecretKey &inputKey, const LweSecretKey &outputKey,
CSPRNG &csprng);
LweBootstrapKey() = delete;
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::LweBootstrapKeyInfo> info)
: buffer(buffer), info(info){};
static LweBootstrapKey
fromProto(const Message<concreteprotocol::LweBootstrapKey> &proto);
Message<concreteprotocol::LweBootstrapKey> toProto() const;
const uint64_t *getRawPtr() const;
size_t getSize() const;
const Message<concreteprotocol::LweBootstrapKeyInfo> &getInfo() const;
const std::vector<uint64_t> &getBuffer() const;
typedef Message<concreteprotocol::LweBootstrapKeyInfo> InfoType;
private:
std::shared_ptr<std::vector<uint64_t>> buffer;
Message<concreteprotocol::LweBootstrapKeyInfo> info;
};
class LweKeyswitchKey {
friend class Keyset;
public:
LweKeyswitchKey(Message<concreteprotocol::LweKeyswitchKeyInfo> info,
const LweSecretKey &inputKey, const LweSecretKey &outputKey,
CSPRNG &csprng);
LweKeyswitchKey() = delete;
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
: buffer(buffer), info(info){};
static LweKeyswitchKey
fromProto(const Message<concreteprotocol::LweKeyswitchKey> &proto);
Message<concreteprotocol::LweKeyswitchKey> toProto() const;
const uint64_t *getRawPtr() const;
size_t getSize() const;
const Message<concreteprotocol::LweKeyswitchKeyInfo> &getInfo() const;
const std::vector<uint64_t> &getBuffer() const;
typedef Message<concreteprotocol::LweKeyswitchKeyInfo> InfoType;
private:
std::shared_ptr<std::vector<uint64_t>> buffer;
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
};
class PackingKeyswitchKey {
friend class Keyset;
public:
PackingKeyswitchKey(Message<concreteprotocol::PackingKeyswitchKeyInfo> info,
const LweSecretKey &inputKey,
const LweSecretKey &outputKey, CSPRNG &csprng);
PackingKeyswitchKey() = delete;
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::PackingKeyswitchKeyInfo> info)
: buffer(buffer), info(info){};
static PackingKeyswitchKey
fromProto(const Message<concreteprotocol::PackingKeyswitchKey> &proto);
Message<concreteprotocol::PackingKeyswitchKey> toProto() const;
const uint64_t *getRawPtr() const;
size_t getSize() const;
const Message<concreteprotocol::PackingKeyswitchKeyInfo> &getInfo() const;
const std::vector<uint64_t> &getBuffer() const;
typedef Message<concreteprotocol::PackingKeyswitchKeyInfo> InfoType;
private:
std::shared_ptr<std::vector<uint64_t>> buffer;
Message<concreteprotocol::PackingKeyswitchKeyInfo> info;
};
} // namespace keys
} // namespace concretelang
#endif

View File

@@ -0,0 +1,79 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_KEYSETS_H
#define CONCRETELANG_COMMON_KEYSETS_H
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keys.h"
#include <functional>
#include <memory>
#include <stdlib.h>
#include <string>
using concretelang::error::Result;
using concretelang::error::StringError;
using concretelang::keys::LweBootstrapKey;
using concretelang::keys::LweKeyswitchKey;
using concretelang::keys::LweSecretKey;
using concretelang::keys::PackingKeyswitchKey;
namespace concretelang {
namespace keysets {
struct ClientKeyset {
std::vector<LweSecretKey> lweSecretKeys;
static ClientKeyset
fromProto(const Message<concreteprotocol::ClientKeyset> &proto);
Message<concreteprotocol::ClientKeyset> toProto() const;
};
struct ServerKeyset {
std::vector<LweBootstrapKey> lweBootstrapKeys;
std::vector<LweKeyswitchKey> lweKeyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
static ServerKeyset
fromProto(const Message<concreteprotocol::ServerKeyset> &proto);
Message<concreteprotocol::ServerKeyset> toProto() const;
};
struct Keyset {
ServerKeyset server;
ClientKeyset client;
/// Generates a fresh keyset from infos.
Keyset(const Message<concreteprotocol::KeysetInfo> &info, CSPRNG &csprng);
Keyset(ServerKeyset server, ClientKeyset client)
: server(server), client(client) {}
static Keyset fromProto(const Message<concreteprotocol::Keyset> &proto);
Message<concreteprotocol::Keyset> toProto() const;
};
class KeysetCache {
std::string backingDirectoryPath;
public:
KeysetCache(std::string backingDirectoryPath);
Result<Keyset>
getKeyset(const Message<concreteprotocol::KeysetInfo> &keysetInfo,
uint64_t seed_msb, uint64_t seed_lsb);
private:
KeysetCache() = default;
};
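// A minimal usage sketch (illustrative only; `keysetInfo` is assumed to come
// from a ProgramInfo produced by the compiler, and the cache directory is an
// arbitrary example path):
//
//   KeysetCache cache("/tmp/keyset-cache");
//   Result<Keyset> keyset =
//       cache.getKeyset(keysetInfo, /*seed_msb=*/0, /*seed_lsb=*/0);
//   // On success, keyset.value().client holds the secret keys and
//   // keyset.value().server holds the evaluation keys.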
} // namespace keysets
} // namespace concretelang
#endif

View File

@@ -0,0 +1,341 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_PROTOCOL_H
#define CONCRETELANG_COMMON_PROTOCOL_H
#include "boost/outcome.h"
#include "capnp/common.h"
#include "capnp/compat/json.h"
#include "capnp/message.h"
#include "capnp/serialize-packed.h"
#include "capnp/serialize.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "kj/common.h"
#include "kj/exception.h"
#include "kj/io.h"
#include "kj/std/iostream.h"
#include "kj/string.h"
#include <algorithm>
#include <cstddef>
#include <fstream>
#include <memory>
#include <optional>
#include <sstream>
#include <vector>
using concretelang::error::Result;
using concretelang::error::StringError;
const uint64_t MAX_SEGMENT_SIZE = capnp::MAX_SEGMENT_WORDS;
namespace concretelang {
namespace protocol {
/// Arena carrying capnp messages.
///
/// This type packs a message with an arena used to store the data in a single
/// object.
///
/// Rationale:
/// ----------
///
/// Capnproto is a performance-oriented serialization framework, which
/// approaches the problem by constructing a memory representation that is
/// already equivalent to the serialized binary representation.
///
/// To make that possible and as fast as possible, they use an arena-passing
/// programming model, which makes serialization fast, but is also pretty
/// invasive:
/// + The top-level message being constructed must be known in advance, so as
/// to properly initialize the MallocMessageBuilder.
/// + The parent message builder must be passed to a function creating a child
/// message, so as to properly initialize the field, and fill it.
/// + The arena must be managed at the top level to ensure that the messages are
/// always pointing to valid memory locations.
///
/// In the compiler, we use the concrete-protocol messages for slightly more
/// than just serialization, and having a less-than-optimal serialization speed
/// is not a big problem for now. For this reason, it makes sense to make the
/// use of the capnp messages a little more ergonomic, which is what this type
/// allows.
template <typename MessageType> struct Message {
Message() : message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder();
message = regionBuilder->initRoot<MessageType>();
}
Message(const typename MessageType::Reader &reader) : message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder(
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
regionBuilder->setRoot(reader);
message = regionBuilder->getRoot<MessageType>();
}
Message(const Message &input) : message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder(
std::min(input.message.asReader().totalSize().wordCount,
MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
regionBuilder->setRoot(input.message.asReader());
message = regionBuilder->getRoot<MessageType>();
}
Message &operator=(const typename MessageType::Reader &reader) {
if (regionBuilder) {
delete regionBuilder;
}
regionBuilder = new capnp::MallocMessageBuilder(
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
regionBuilder->setRoot(reader);
message = regionBuilder->getRoot<MessageType>();
return *this;
}
Message &operator=(const Message &input) {
if (this != &input) {
if (regionBuilder) {
delete regionBuilder;
}
regionBuilder = new capnp::MallocMessageBuilder(
std::min(input.message.asReader().totalSize().wordCount,
MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
regionBuilder->setRoot(input.message.asReader());
message = regionBuilder->getRoot<MessageType>();
}
return *this;
}
Message(Message &&input) : message(nullptr) {
regionBuilder = input.regionBuilder;
message = input.message;
input.regionBuilder = nullptr;
}
Message &operator=(Message &&input) {
if (this != &input) {
if (regionBuilder) {
delete regionBuilder;
}
regionBuilder = input.regionBuilder;
message = input.message;
input.regionBuilder = nullptr;
}
return *this;
}
~Message() {
if (regionBuilder) {
delete regionBuilder;
}
}
typename MessageType::Reader asReader() const { return message.asReader(); }
typename MessageType::Builder asBuilder() { return message; }
Result<void> writeBinaryToFd(int fd) const {
try {
capnp::writeMessageToFd(fd, *regionBuilder);
return outcome::success();
} catch (const kj::Exception &e) {
return StringError("Failed to write message to file descriptor: ")
<< e.getDescription().cStr();
}
}
Result<void> writeBinaryToOstream(std::ostream &ostream) const {
try {
kj::std::StdOutputStream kjOstream(ostream);
capnp::writeMessage(kjOstream, *regionBuilder);
} catch (const kj::Exception &e) {
return StringError("Failed to write message to ostream: ")
<< e.getDescription().cStr();
}
ostream.flush();
if (!ostream.good()) {
return StringError(
"Failed to write message to ostream. Ended up in bad state.");
}
return outcome::success();
}
Result<std::string> writeBinaryToString() const {
auto ostream = std::ostringstream();
OUTCOME_TRYV(this->writeBinaryToOstream(ostream));
return outcome::success(ostream.str());
}
Result<std::string> writeJsonToString() const {
try {
capnp::JsonCodec json;
kj::String output = json.encode(this->message.asReader());
return outcome::success(std::string(output.cStr(), output.size()));
} catch (const kj::Exception &e) {
return outcome::failure(
StringError("Failed to write message to json string: ")
<< e.getDescription().cStr());
}
}
Result<void> readBinaryFromFd(int fd) {
try {
capnp::readMessageCopyFromFd(fd, *regionBuilder);
this->message = regionBuilder->getRoot<MessageType>();
return outcome::success();
} catch (const kj::Exception &e) {
return StringError("Failed to read message from file descriptor: ")
<< e.getDescription().cStr();
}
}
Result<void>
readBinaryFromIstream(std::istream &istream,
capnp::ReaderOptions options = capnp::ReaderOptions()) {
try {
kj::std::StdInputStream kjIstream(istream);
capnp::readMessageCopy(kjIstream, *regionBuilder, options);
this->message = regionBuilder->getRoot<MessageType>();
return outcome::success();
} catch (const kj::Exception &e) {
return StringError("Failed to read message from istream: ")
<< e.getDescription().cStr();
}
}
Result<void>
readBinaryFromString(const std::string &input,
capnp::ReaderOptions options = capnp::ReaderOptions()) {
auto istream = std::istringstream(input);
return this->readBinaryFromIstream(istream, options);
}
Result<void> readJsonFromString(const std::string &input) {
try {
capnp::JsonCodec json;
kj::StringPtr stringPointer(input.c_str(), input.size());
this->message = this->regionBuilder->template initRoot<MessageType>();
json.decode(stringPointer, this->message);
return outcome::success();
} catch (const kj::Exception &e) {
return StringError("Failed to read message from json string: ")
<< e.getDescription().cStr();
}
}
std::string debugString() const { return writeJsonToString().value(); }
private:
capnp::MallocMessageBuilder *regionBuilder;
typename MessageType::Builder message;
};
template struct Message<concreteprotocol::ProgramInfo>;
template struct Message<concreteprotocol::CircuitEncodingInfo>;
template struct Message<concreteprotocol::Value>;
template struct Message<concreteprotocol::GateInfo>;
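// A minimal usage sketch of the wrapper above (illustrative only): a message is
// default-constructed, serialized to a string, and read back into a fresh
// arena. No schema field is assumed here.
//
//   Message<concreteprotocol::ProgramInfo> info;
//   auto serialized = info.writeBinaryToString();
//   Message<concreteprotocol::ProgramInfo> copy;
//   auto ok = copy.readBinaryFromString(serialized.value());
//   assert(!ok.has_failure());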
/// Helper function turning a vector of integers to a payload.
template <typename T>
Message<concreteprotocol::Payload>
vectorToProtoPayload(const std::vector<T> &input) {
auto output = Message<concreteprotocol::Payload>();
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
auto remainingElms = input.size() % elmsPerBlob;
auto nbBlobs = (input.size() / elmsPerBlob) + (remainingElms > 0);
auto dataBuilder = output.asBuilder().initData(nbBlobs);
// Process all but the last blob; each stores as much data as `Data` allows.
if (nbBlobs > 1) {
for (size_t blobIndex = 0; blobIndex < nbBlobs - 1; blobIndex++) {
auto blobPtr = input.data() + blobIndex * elmsPerBlob;
auto blobLen = elmsPerBlob * sizeof(T);
dataBuilder.set(
blobIndex,
capnp::Data::Reader(reinterpret_cast<const unsigned char *>(blobPtr),
blobLen));
}
}
// Process the last blob, which stores the remainder (or a full blob when the
// input size is an exact multiple of elmsPerBlob).
if (nbBlobs > 0) {
auto lastBlobIndex = nbBlobs - 1;
auto lastBlobPtr = input.data() + lastBlobIndex * elmsPerBlob;
auto lastBlobLen =
(remainingElms > 0 ? remainingElms : elmsPerBlob) * sizeof(T);
dataBuilder.set(
lastBlobIndex,
capnp::Data::Reader(
reinterpret_cast<const unsigned char *>(lastBlobPtr), lastBlobLen));
}
return output;
}
/// Helper function turning a payload to a vector of integers.
template <typename T>
std::vector<T>
protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
totalPayloadSize += blob.size();
}
assert(totalPayloadSize % sizeof(T) == 0);
auto dataSize = totalPayloadSize / sizeof(T);
auto output = std::vector<T>();
output.resize(dataSize);
for (size_t blobIndex = 0; blobIndex < payloadData.size(); blobIndex++) {
auto blobData = payloadData[blobIndex];
auto blobPtr = output.data() + blobIndex * elmsPerBlob;
std::memcpy(blobPtr, blobData.begin(), blobData.size());
}
return output;
}
/// Helper function turning a payload to a shared vector of integers on the
/// heap.
template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
totalPayloadSize += blob.size();
}
assert(totalPayloadSize % sizeof(T) == 0);
size_t dataSize = totalPayloadSize / sizeof(T);
auto output = std::make_shared<std::vector<T>>();
output->resize(dataSize);
for (size_t blobIndex = 0; blobIndex < payloadData.size(); blobIndex++) {
auto blobData = payloadData[blobIndex];
auto blobPtr = output->data() + blobIndex * elmsPerBlob;
std::memcpy(blobPtr, blobData.begin(), blobData.size());
}
return output;
}
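// A round-trip sketch for the two helpers above (illustrative only):
//
//   std::vector<uint64_t> key = {1, 2, 3, 4};
//   auto payload = vectorToProtoPayload(key);
//   std::vector<uint64_t> back = protoPayloadToVector<uint64_t>(payload);
//   assert(back == key);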
/// Helper function turning a protocol `Shape` object into a vector of
/// dimensions.
std::vector<size_t>
protoShapeToDimensions(const Message<concreteprotocol::Shape> &shape);
/// Helper function turning a vector of dimensions into a protocol `Shape`
/// object.
Message<concreteprotocol::Shape>
dimensionsToProtoShape(const std::vector<size_t> &input);
template <typename MessageType> size_t hashMessage(Message<MessageType> &mess);
} // namespace protocol
} // namespace concretelang
#endif

View File

@@ -0,0 +1,90 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_TRANSFORMERS_H
#define CONCRETELANG_COMMON_TRANSFORMERS_H
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Values.h"
#include <memory>
#include <stdlib.h>
using concretelang::error::Result;
using concretelang::keysets::ClientKeyset;
using concretelang::values::Tensor;
using concretelang::values::TransportValue;
using concretelang::values::Value;
namespace concretelang {
namespace transformers {
/// A type for input transformers, that is, functions running on the client
/// side, that prepare a Value to be sent to the server as a TransportValue.
typedef std::function<Result<TransportValue>(Value)> InputTransformer;
/// A type for output transformers, that is, functions running on the client
/// side, that process a TransportValue fetched from the server to be used as a
/// Value.
typedef std::function<Result<Value>(TransportValue)> OutputTransformer;
/// A type for arguments transformers, that is, functions running on the server
/// side, that transform a TransportValue fetched from the client, to be used as
/// argument in a circuit call.
typedef std::function<Result<Value>(TransportValue)> ArgTransformer;
/// A type for return transformers, that is, functions running on the server
/// side, that transform a value returned from a circuit call into a
/// TransportValue to be sent to the client.
typedef std::function<Result<TransportValue>(Value)> ReturnTransformer;
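// As an illustration, the simplest possible ArgTransformer just deserializes
// the transport value without any further processing (sketch only; this exact
// transformer is not part of the factory below):
//
//   ArgTransformer identity = [](TransportValue transportVal) -> Result<Value> {
//     return Value::fromRawTransportValue(transportVal);
//   };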
/// A static factory class that generates transformers.
class TransformerFactory {
public:
static Result<InputTransformer>
getIndexInputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<OutputTransformer>
getIndexOutputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<ArgTransformer>
getIndexArgTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<ReturnTransformer>
getIndexReturnTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<InputTransformer>
getPlaintextInputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<OutputTransformer>
getPlaintextOutputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<ArgTransformer>
getPlaintextArgTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<ReturnTransformer>
getPlaintextReturnTransformer(Message<concreteprotocol::GateInfo> gateInfo);
static Result<InputTransformer> getLweCiphertextInputTransformer(
ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
std::shared_ptr<CSPRNG> csprng, bool useSimulation);
static Result<OutputTransformer> getLweCiphertextOutputTransformer(
ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
bool useSimulation);
static Result<ArgTransformer>
getLweCiphertextArgTransformer(Message<concreteprotocol::GateInfo> gateInfo,
bool useSimulation);
static Result<ReturnTransformer> getLweCiphertextReturnTransformer(
Message<concreteprotocol::GateInfo> gateInfo, bool useSimulation);
};
} // namespace transformers
} // namespace concretelang
#endif

View File

@@ -0,0 +1,217 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_COMMON_VALUES_H
#define CONCRETELANG_COMMON_VALUES_H
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Protocol.h"
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <optional>
#include <stdlib.h>
#include <variant>
using concretelang::error::Result;
using concretelang::error::StringError;
using concretelang::protocol::dimensionsToProtoShape;
using concretelang::protocol::Message;
using concretelang::protocol::protoPayloadToVector;
using concretelang::protocol::protoShapeToDimensions;
using concretelang::protocol::vectorToProtoPayload;
namespace concretelang {
namespace values {
/// A type for public (encrypted or not) values, that can be safely transported
/// between client and server for execution.
typedef Message<concreteprotocol::Value> TransportValue;
/// A type for tensor data.
template <typename T> struct Tensor {
std::vector<T> values;
std::vector<size_t> dimensions;
Tensor<T>() = default;
Tensor<T>(std::vector<T> values, std::vector<size_t> dimensions)
: values(values), dimensions(dimensions) {}
/// Creates a tensor with the shape described by the input dimensions, filled
/// with zeros.
static Tensor<T> fromDimensions(std::vector<size_t> &dimensions) {
uint32_t length = 1;
for (auto dim : dimensions) {
length *= dim;
}
auto values = std::vector<T>(length);
for (auto &val : values) {
val = 0;
}
return Tensor{values, dimensions};
}
/// Conversion constructor from a scalar value.
Tensor<T>(T in) { this->values.push_back(in); }
/// Constructor from initializer lists of values and dimensions.
Tensor<T>(std::initializer_list<T> values,
std::initializer_list<size_t> dimensions) {
size_t count = 1;
for (auto dim : dimensions) {
count *= dim;
}
assert(values.size() == count);
for (auto val : values) {
this->values.push_back(val);
}
for (auto dim : dimensions) {
this->dimensions.push_back(dim);
}
}
bool operator==(const Tensor<T> &b) const {
return this->values == b.values && this->dimensions == b.dimensions;
}
Tensor<T> operator-(T b) const {
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] -= b;
}
return out;
}
Tensor<T> operator-(Tensor<T> b) const {
assert(this->dimensions == b.dimensions);
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] -= b.values[i];
}
return out;
}
Tensor<T> operator+(T b) const {
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] += b;
}
return out;
}
Tensor<T> operator+(Tensor<T> b) const {
assert(this->dimensions == b.dimensions);
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] += b.values[i];
}
return out;
}
Tensor<T> operator*(T b) const {
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] *= b;
}
return out;
}
Tensor<T> operator*(Tensor<T> b) const {
assert(this->dimensions == b.dimensions);
Tensor<T> out = *this;
for (size_t i = 0; i < out.values.size(); i++) {
out.values[i] *= b.values[i];
}
return out;
}
T &operator[](int index) { return this->values[index]; }
template <typename U> explicit operator Tensor<U>() const {
Tensor<U> output;
output.dimensions = this->dimensions;
for (auto v : this->values) {
output.values.push_back((U)v);
}
return output;
}
bool isScalar() const { return dimensions.empty(); }
};
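// A minimal usage sketch (illustrative only):
//
//   Tensor<uint8_t> t({1, 2, 3, 4}, {2, 2});  // 2x2 tensor
//   Tensor<uint8_t> shifted = t + (uint8_t)1; // values {2, 3, 4, 5}
//   Tensor<uint64_t> widened = (Tensor<uint64_t>)shifted;
//   assert(!t.isScalar() && Tensor<uint8_t>(7).isScalar());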
/// A type for tensor data of varying precisions, mainly used to manipulate
/// circuit inputs and outputs.
struct Value {
friend class ClientCircuit;
std::variant<Tensor<uint8_t>, Tensor<int8_t>, Tensor<uint16_t>,
Tensor<int16_t>, Tensor<uint32_t>, Tensor<int32_t>,
Tensor<uint64_t>, Tensor<int64_t>>
inner;
Value() = default;
Value(Tensor<uint8_t> inner) : inner(inner){};
Value(Tensor<uint16_t> inner) : inner(inner){};
Value(Tensor<uint32_t> inner) : inner(inner){};
Value(Tensor<uint64_t> inner) : inner(inner){};
Value(Tensor<int8_t> inner) : inner(inner){};
Value(Tensor<int16_t> inner) : inner(inner){};
Value(Tensor<int32_t> inner) : inner(inner){};
Value(Tensor<int64_t> inner) : inner(inner){};
/// Turns a server value to a client value, without interpreting the kind of
/// value.
static Value fromRawTransportValue(TransportValue transportVal);
/// Turns a client value to a raw (without kind info attached) server value.
TransportValue intoRawTransportValue() const;
bool operator==(const Value &b) const;
uint32_t getIntegerPrecision() const;
bool isSigned() const;
Message<concreteprotocol::Payload> intoProtoPayload() const;
Message<concreteprotocol::Shape> intoProtoShape() const;
std::vector<size_t> getDimensions() const;
size_t getLength() const;
template <typename T> bool hasElementType() const {
return std::holds_alternative<Tensor<T>>(inner);
}
template <typename T> std::optional<Tensor<T>> getTensor() const {
if (!hasElementType<T>()) {
return std::nullopt;
}
return std::get<Tensor<T>>(inner);
}
template <typename T> Tensor<T> *getTensorPtr() {
if (!hasElementType<T>()) {
return nullptr;
}
return &std::get<Tensor<T>>(inner);
}
bool
isCompatibleWithShape(const Message<concreteprotocol::Shape> &shape) const;
bool isScalar() const;
Value toUnsigned() const;
Value toSigned() const;
};
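// A minimal round-trip sketch (illustrative only): a clear scalar is wrapped
// into a Value, turned into a TransportValue, and recovered on the other side.
//
//   Value v(Tensor<uint64_t>(17));
//   TransportValue onWire = v.intoRawTransportValue();
//   Value back = Value::fromRawTransportValue(onWire);
//   assert(back == v && back.isScalar());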
size_t getCorrespondingPrecision(size_t originalPrecision);
} // namespace values
} // namespace concretelang
#endif

View File

@@ -59,14 +59,14 @@ def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
"mlir::concretelang::TFHE::GLWESecretKey":$inputKey,
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
"int" : $outputPolySize,
"int" : $inputLweDim,
"int" : $innerLweDim,
"int" : $glweDim,
"int" : $levels,
"int" : $baseLog,
DefaultValuedParameter<"int", "-1">: $index
);
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $innerLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
}

View File

@@ -6,16 +6,16 @@
#ifndef CONCRETELANG_RUNTIME_CONTEXT_H
#define CONCRETELANG_RUNTIME_CONTEXT_H
#include "concrete-cpu.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include <assert.h>
#include <map>
#include <mutex>
#include <pthread.h>
#include <vector>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Common/Error.h"
#include "concrete-cpu.h"
using ::concretelang::keysets::ServerKeyset;
#ifdef CONCRETELANG_CUDA_SUPPORT
#include "bootstrap.h"
@@ -40,7 +40,7 @@ typedef struct FFT {
typedef struct RuntimeContext {
RuntimeContext() = delete;
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
RuntimeContext(ServerKeyset serverKeyset);
~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
for (int i = 0; i < num_devices; ++i) {
@@ -53,7 +53,7 @@ typedef struct RuntimeContext {
};
const uint64_t *keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getKeyswitchKey(keyId).buffer();
return serverKeyset.lweKeyswitchKeys[keyId].getRawPtr();
}
const double *fourier_bootstrap_key_buffer(size_t keyId) {
@@ -61,17 +61,15 @@ typedef struct RuntimeContext {
}
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
return evaluationKeys.getPackingKeyswitchKey(keyId).buffer();
return serverKeyset.packingKeyswitchKeys[keyId].getRawPtr();
}
const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
const ::concretelang::clientlib::EvaluationKeys getKeys() const {
return evaluationKeys;
}
const ServerKeyset getKeys() const { return serverKeyset; }
private:
::concretelang::clientlib::EvaluationKeys evaluationKeys;
ServerKeyset serverKeyset;
std::vector<std::shared_ptr<std::vector<double>>> fourier_bootstrap_keys;
std::vector<FFT> ffts;
@@ -89,15 +87,15 @@ public:
return bsk_gpu[gpu_idx];
}
auto bsk = evaluationKeys.getBootstrapKey(0);
auto bsk = serverKeyset.lweBootstrapKeys[0];
size_t bsk_buffer_len = bsk.size();
size_t bsk_buffer_len = bsk.getBuffer().size();
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
void *bsk_gpu_tmp =
cuda_malloc_async(bsk_gpu_buffer_size, (cudaStream_t *)stream, gpu_idx);
cuda_convert_lwe_bootstrap_key_64(
bsk_gpu_tmp, const_cast<uint64_t *>(bsk.buffer()),
bsk_gpu_tmp, const_cast<uint64_t *>(bsk.getBuffer().data()),
(cudaStream_t *)stream, gpu_idx, input_lwe_dim, glwe_dim, level,
poly_size);
// Synchronization here is not optional as it works with mutex to
@@ -118,14 +116,15 @@ public:
if (ksk_gpu[gpu_idx] != nullptr) {
return ksk_gpu[gpu_idx];
}
auto ksk = evaluationKeys.getKeyswitchKey(0);
auto ksk = serverKeyset.lweKeyswitchKeys[0];
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size();
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.getBuffer().size();
void *ksk_gpu_tmp =
cuda_malloc_async(ksk_buffer_size, (cudaStream_t *)stream, gpu_idx);
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, const_cast<uint64_t *>(ksk.buffer()),
cuda_memcpy_async_to_gpu(ksk_gpu_tmp,
const_cast<uint64_t *>(ksk.getBuffer().data()),
ksk_buffer_size, (cudaStream_t *)stream, gpu_idx);
// Synchronization here is not optional as it works with mutex to
// prevent other GPU streams from reading partially copied keys.

View File

@@ -29,7 +29,6 @@
#include <mlir/ExecutionEngine/CRunnerUtils.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Runtime/dfr_debug_interface.h"

View File

@@ -15,27 +15,30 @@
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/serialization.hpp>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keys.h"
#include "concretelang/Common/Keysets.h"
using concretelang::keys::LweBootstrapKey;
using concretelang::keys::LweKeyswitchKey;
using concretelang::keys::LweSecretKey;
using concretelang::keys::PackingKeyswitchKey;
using concretelang::keysets::ServerKeyset;
namespace mlir {
namespace concretelang {
namespace dfr {
using namespace ::concretelang::clientlib;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
template <typename LweKeyType, typename KeyParamType> struct KeyWrapper {
template <typename LweKeyType> struct KeyWrapper {
std::vector<LweKeyType> keys;
KeyWrapper() {}
@@ -51,38 +54,44 @@ template <typename LweKeyType, typename KeyParamType> struct KeyWrapper {
void save(Archive &ar, const unsigned int version) const {
ar << (size_t)keys.size();
for (auto k : keys) {
auto params = k.parameters();
size_t param_size = sizeof(KeyParamType);
ar << hpx::serialization::make_array((char *)&params, param_size);
ar << (size_t)k.size();
ar << hpx::serialization::make_array(k.buffer(), k.size());
auto info = k.getInfo();
auto maybe_info_string = info.writeBinaryToString();
assert(maybe_info_string.has_value());
auto info_string = maybe_info_string.value();
ar << hpx::serialization::make_array(info_string.c_str(),
info_string.size());
ar << (size_t)k.getBuffer().size();
ar << hpx::serialization::make_array(k.getBuffer().data(),
k.getBuffer().size());
}
}
template <class Archive> void load(Archive &ar, const unsigned int version) {
size_t num_keys;
ar >> num_keys;
for (uint i = 0; i < num_keys; ++i) {
KeyParamType params;
size_t param_size = sizeof(params);
ar >> hpx::serialization::make_array((char *)&params, param_size);
std::string info_string;
ar >> info_string;
typename LweKeyType::InfoType info;
assert(info.readBinaryFromString(info_string).has_value());
size_t key_size;
ar >> key_size;
auto buffer = std::make_shared<std::vector<uint64_t>>();
buffer->resize(key_size);
ar >> hpx::serialization::make_array(buffer->data(), key_size);
keys.push_back(LweKeyType(buffer, params));
keys.push_back(LweKeyType(buffer, info));
}
}
HPX_SERIALIZATION_SPLIT_MEMBER()
};
template <typename LweKeyType, typename KeyParamType>
bool operator==(const KeyWrapper<LweKeyType, KeyParamType> &lhs,
const KeyWrapper<LweKeyType, KeyParamType> &rhs) {
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
const KeyWrapper<LweKeyType> &rhs) {
if (lhs.keys.size() != rhs.keys.size())
return false;
for (size_t i = 0; i < lhs.keys.size(); ++i)
if (lhs.keys[i].buffer() != rhs.keys[i].buffer())
if (lhs.keys[i].getBuffer() != rhs.keys[i].getBuffer())
return false;
return true;
}
@@ -110,21 +119,21 @@ struct RuntimeContextManager {
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw(
context->getKeys().getKeyswitchKeys());
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw(
context->getKeys().getBootstrapKeys());
KeyWrapper<LweKeyswitchKey> kskw(context->getKeys().lweKeyswitchKeys);
KeyWrapper<LweBootstrapKey> bskw(context->getKeys().lweBootstrapKeys);
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
} else {
auto kskFut = hpx::collectives::broadcast_from<
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam>>("ksk_keystore");
auto bskFut = hpx::collectives::broadcast_from<
KeyWrapper<LweBootstrapKey, BootstrapKeyParam>>("bsk_keystore");
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw = bskFut.get();
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey>>(
"ksk_keystore");
auto bskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey>>(
"bsk_keystore");
KeyWrapper<LweKeyswitchKey> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext(
EvaluationKeys(kskw.keys, bskw.keys, {}));
ServerKeyset{bskw.keys, kskw.keys, {}});
}
}
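
Editor's note: the new KeyWrapper save/load pair ships each key as its serialized capnp key-info message followed by its raw 64-bit buffer, instead of the old fixed-size KeyParamType blob. Below is a minimal, standalone sketch of that per-key layout only; it is not the exact HPX archive format (HPX handles string and array framing itself), and FakeKey, saveKeys and loadKeys are hypothetical stand-ins for the real key types and archive operators.

// Standalone sketch: each key travels as a length-prefixed info string
// followed by a length-prefixed buffer of 64-bit words.
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

struct FakeKey {
  std::string info;             // stands in for the serialized capnp KeyInfo
  std::vector<uint64_t> buffer; // stands in for the raw key material
};

static void saveKeys(std::ostream &os, const std::vector<FakeKey> &keys) {
  size_t n = keys.size();
  os.write(reinterpret_cast<const char *>(&n), sizeof(n));
  for (const auto &k : keys) {
    size_t infoSize = k.info.size();
    os.write(reinterpret_cast<const char *>(&infoSize), sizeof(infoSize));
    os.write(k.info.data(), infoSize);
    size_t bufSize = k.buffer.size();
    os.write(reinterpret_cast<const char *>(&bufSize), sizeof(bufSize));
    os.write(reinterpret_cast<const char *>(k.buffer.data()),
             bufSize * sizeof(uint64_t));
  }
}

static std::vector<FakeKey> loadKeys(std::istream &is) {
  size_t n = 0;
  is.read(reinterpret_cast<char *>(&n), sizeof(n));
  std::vector<FakeKey> keys(n);
  for (auto &k : keys) {
    size_t infoSize = 0;
    is.read(reinterpret_cast<char *>(&infoSize), sizeof(infoSize));
    k.info.resize(infoSize);
    is.read(&k.info[0], infoSize);
    size_t bufSize = 0;
    is.read(reinterpret_cast<char *>(&bufSize), sizeof(bufSize));
    k.buffer.resize(bufSize);
    is.read(reinterpret_cast<char *>(k.buffer.data()),
            bufSize * sizeof(uint64_t));
  }
  return keys;
}

int main() {
  std::stringstream ss;
  saveKeys(ss, {{"ksk-info", {1, 2, 3}}, {"bsk-info", {4, 5}}});
  auto keys = loadKeys(ss);
  assert(keys.size() == 2 && keys[1].buffer[1] == 5);
  return 0;
}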

View File

@@ -17,6 +17,13 @@ typedef enum stream_type {
TS_STREAM_TYPE_TOPO_TO_X86_LSAP,
TS_STREAM_TYPE_X86_TO_X86_LSAP
} stream_type;
template <size_t N> struct MemRefDescriptor {
uint64_t *allocated;
uint64_t *aligned;
size_t offset;
size_t sizes[N];
size_t strides[N];
};
extern "C" {
void *stream_emulator_init();
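
Editor's note: this descriptor mirrors the layout MLIR uses when lowering ranked memrefs: a pointer pair, an offset, and per-dimension sizes and strides. The indexing formula in the self-contained sketch below is the usual MLIR memref convention and is stated as an assumption, not taken from this header.

// Minimal sketch: element (i, j) of a rank-2 descriptor is
// aligned[offset + i * strides[0] + j * strides[1]] (assumed MLIR memref ABI).
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

template <size_t N> struct MemRefDescriptor {
  uint64_t *allocated;
  uint64_t *aligned;
  size_t offset;
  size_t sizes[N];
  size_t strides[N];
};

int main() {
  std::vector<uint64_t> storage = {0, 1, 2, 3, 4, 5};
  // A contiguous 2x3 view over the storage: row stride 3, column stride 1.
  MemRefDescriptor<2> desc{storage.data(), storage.data(), 0, {2, 3}, {3, 1}};
  size_t i = 1, j = 2;
  uint64_t v =
      desc.aligned[desc.offset + i * desc.strides[0] + j * desc.strides[1]];
  std::printf("element (1, 2) = %llu\n", (unsigned long long)v); // prints 5
  return 0;
}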

View File

@@ -1,42 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::ClientParameters;
using concretelang::error::StringError;
class DynamicModule {
public:
~DynamicModule();
static outcome::checked<std::shared_ptr<DynamicModule>, StringError>
open(std::string outputPath);
private:
outcome::checked<void, StringError>
loadClientParametersJSON(std::string outputPath);
outcome::checked<void, StringError> loadSharedLibrary(std::string outputPath);
private:
std::vector<ClientParameters> clientParametersList;
void *libraryHandle;
friend class ServerLambda;
};
} // namespace serverlib
} // namespace concretelang
#endif

View File

@@ -1,63 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#include <cassert>
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/DynamicModule.h"
#include "concretelang/Support/Error.h"
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::ScalarOrTensorData;
/// ServerLambda is a utility class that allows to call a function of a
/// compilation result.
class ServerLambda {
public:
/// Load the symbol `funcName` from the shared lib in the artifacts folder
/// located in `outputPath`
static outcome::checked<ServerLambda, concretelang::error::StringError>
load(std::string funcName, std::string outputPath);
/// Load the symbol `funcName` of the dynamic loaded library
static outcome::checked<ServerLambda, concretelang::error::StringError>
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
/// Call the ServerLambda with public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
call(clientlib::PublicArguments &args,
std::optional<clientlib::EvaluationKeys> evaluationKeys,
bool simulation = false);
/// \brief Call the loaded function using opaque pointers to both inputs and
/// outputs.
/// \param args Array containing pointers to inputs first, followed by
/// pointers to outputs.
/// \return Error if failed, success otherwise.
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
protected:
ClientParameters clientParameters;
/// holds a pointer to the entrypoint of the shared lib which implements the
/// function
void (*func)(void *...);
/// Retain module and open shared lib alive
std::shared_ptr<DynamicModule> module;
};
} // namespace serverlib
} // namespace concretelang
#endif

View File

@@ -0,0 +1,102 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
#include "boost/outcome.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Transformers.h"
#include "concretelang/Common/Values.h"
#include "llvm/ADT/ArrayRef.h"
#include <cassert>
#include <dlfcn.h>
#include <functional>
#include <memory>
#include <vector>
using concretelang::keysets::ServerKeyset;
using concretelang::transformers::ArgTransformer;
using concretelang::transformers::ReturnTransformer;
using concretelang::transformers::TransformerFactory;
using concretelang::values::Value;
namespace concretelang {
namespace serverlib {
/// A dynamically loaded module, usually handled through a shared pointer.
class DynamicModule {
friend class ServerCircuit;
public:
~DynamicModule();
static Result<std::shared_ptr<DynamicModule>>
open(const std::string &outputPath);
private:
void *libraryHandle;
};
class ServerCircuit {
friend class ServerProgram;
public:
/// Call the circuit with public arguments.
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
std::vector<TransportValue> &args);
Result<std::vector<TransportValue>>
simulate(std::vector<TransportValue> &args);
/// Returns the name of this circuit.
std::string getName();
private:
ServerCircuit() = default;
static Result<ServerCircuit>
fromDynamicModule(const Message<concreteprotocol::CircuitInfo> &circuitInfo,
std::shared_ptr<DynamicModule> dynamicModule,
bool useSimulation);
void invoke(const ServerKeyset &serverKeyset);
Message<concreteprotocol::CircuitInfo> circuitInfo;
bool useSimulation;
void (*func)(void *...);
std::shared_ptr<DynamicModule> dynamicModule;
std::vector<ArgTransformer> argTransformers;
std::vector<ReturnTransformer> returnTransformers;
std::vector<Value> argsBuffer;
std::vector<Value> returnsBuffer;
std::vector<size_t> argDescriptorSizes;
std::vector<size_t> returnDescriptorSizes;
size_t argRawSize;
size_t returnRawSize;
};
/// ServerProgram groups the server circuits coming from the same compilation.
class ServerProgram {
public:
/// Loads a server program from the compiled shared library located at `outputPath`.
static Result<ServerProgram>
load(const Message<concreteprotocol::ProgramInfo> &programInfo,
const std::string &outputPath, bool useSimulation);
Result<ServerCircuit> getServerCircuit(const std::string &circuitName);
private:
ServerProgram() = default;
std::vector<ServerCircuit> serverCircuits;
};
} // namespace serverlib
} // namespace concretelang
#endif
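
Editor's note: a hedged sketch of how a host might drive this new API, built only from the declarations above. The include path, the circuit name "main", and unwrapping a Result<T> with .value() are assumptions; unqualified names (Message, TransportValue) mirror the header's own usings.

// Hypothetical server-side driver; compiles only against the library.
// #include "concretelang/ServerLib/ServerLambda.h"  // path assumed
#include <string>
#include <vector>

using concretelang::serverlib::ServerCircuit;
using concretelang::serverlib::ServerProgram;

void serveOnce(const Message<concreteprotocol::ProgramInfo> &programInfo,
               const ServerKeyset &keyset,
               std::vector<TransportValue> &encryptedArgs,
               const std::string &artifactsDir) {
  // Load the compiled artifacts; useSimulation = false runs on ciphertexts.
  auto program =
      ServerProgram::load(programInfo, artifactsDir, /*useSimulation=*/false);
  // Assumption: Result<T> exposes value() on success; error handling omitted.
  ServerCircuit circuit = program.value().getServerCircuit("main").value();
  // Execute the circuit on the transport values received from the client.
  auto results = circuit.call(keyset, encryptedArgs);
  (void)results; // results.value() would be shipped back to the client.
}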

View File

@@ -1,31 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/IR/BuiltinOps.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/V0Parameters.h"
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::ChunkInfo;
using ::concretelang::clientlib::ClientParameters;
llvm::Expected<ClientParameters>
createClientParametersFromTFHE(mlir::ModuleOp module,
llvm::StringRef functionName, int bitsOfSecurity,
encodings::CircuitEncodings encodings,
std::optional<CRTDecomposition> maybeCrt);
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -9,9 +9,15 @@
#include <cstddef>
#include <vector>
#include "concretelang/ClientLib/ClientParameters.h"
#include "boost/outcome.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Protocol.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
namespace protocol = concreteprotocol;
using concretelang::protocol::Message;
namespace mlir {
namespace concretelang {
@@ -75,9 +81,8 @@ struct CompilationFeedback {
/// @brief memory usage per location
std::map<std::string, int64_t> memoryUsagePerLoc;
/// Fill the sizes from the client parameters.
void
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
/// Fill the sizes from the program info.
void fillFromProgramInfo(const Message<protocol::ProgramInfo> &params);
/// Load the compilation feedback from a path
static outcome::checked<CompilationFeedback, StringError>
@@ -85,6 +90,7 @@ struct CompilationFeedback {
};
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &);
bool fromJSON(const llvm::json::Value,
mlir::concretelang::CompilationFeedback &, llvm::json::Path);
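
Editor's note: the JSON round trip for the feedback keeps working through the declared toJSON/fromJSON overloads. A short hedged sketch of pretty-printing it, reusing the formatv("{0:2}", ...) idiom seen elsewhere in this codebase; dumpFeedback and the include path are made up for illustration.

// Hypothetical helper: serialize the feedback to indented JSON text.
// #include "concretelang/Support/CompilationFeedback.h"  // path assumed
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"

void dumpFeedback(const mlir::concretelang::CompilationFeedback &feedback) {
  llvm::json::Value value = toJSON(feedback);
  // "{0:2}" pretty-prints the JSON value with two-space indentation.
  llvm::outs() << llvm::formatv("{0:2}", value) << "\n";
}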

View File

@@ -6,15 +6,22 @@
#ifndef CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
#include <concretelang/Support/ClientParametersGeneration.h>
#include <concretelang/Support/Encodings.h>
#include <llvm/IR/Module.h>
#include <llvm/Support/Error.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/Pass/Pass.h>
#include "capnp/message.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/ProgramInfoGeneration.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/SourceMgr.h"
#include <memory>
#include <optional>
using concretelang::protocol::Message;
namespace mlir {
namespace concretelang {
@@ -71,7 +78,7 @@ struct CompilationOptions {
bool emitGPUOps;
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
std::optional<std::string> clientParametersFuncName;
std::optional<std::string> mainFuncName;
optimizer::Config optimizerConfig;
@@ -84,7 +91,7 @@ struct CompilationOptions {
/// When compiling from a dialect lower than FHE, one needs to provide
/// encodings info manually to allow the client lib to be generated.
std::optional<mlir::concretelang::encodings::CircuitEncodings> encodings;
std::optional<Message<concreteprotocol::CircuitEncodingInfo>> encodings;
CompilationOptions()
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
@@ -92,12 +99,12 @@ struct CompilationOptions {
maxBatchSize(std::numeric_limits<int64_t>::max()), emitSDFGOps(false),
unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false),
optimizeTFHE(true), simulate(false), emitGPUOps(false),
clientParametersFuncName(std::nullopt),
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
chunkSize(4), chunkWidth(2), encodings(std::nullopt){};
mainFuncName(std::nullopt), optimizerConfig(optimizer::DEFAULT_CONFIG),
chunkIntegers(false), chunkSize(4), chunkWidth(2),
encodings(std::nullopt){};
CompilationOptions(std::string funcname) : CompilationOptions() {
clientParametersFuncName = funcname;
mainFuncName = funcname;
}
/// @brief Constructor for CompilationOptions with default parameters for a
@@ -130,7 +137,7 @@ public:
: compilationContext(compilationContext) {}
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
std::optional<mlir::concretelang::ClientParameters> clientParameters;
std::optional<Message<concreteprotocol::ProgramInfo>> programInfo;
std::optional<CompilationFeedback> feedback;
std::unique_ptr<llvm::Module> llvmModule;
std::optional<mlir::concretelang::V0FHEContext> fheContext;
@@ -142,12 +149,11 @@ public:
class Library {
std::string outputDirPath;
std::vector<std::string> objectsPath;
std::vector<mlir::concretelang::ClientParameters> clientParametersList;
std::vector<mlir::concretelang::CompilationFeedback>
compilationFeedbackList;
/// Path to the runtime library. Will be linked to the output library if set
std::string runtimeLibraryPath;
bool cleanUp;
mlir::concretelang::CompilationFeedback compilationFeedback;
Message<concreteprotocol::ProgramInfo> programInfo;
public:
/// Create a library instance on which you can add compilation results.
@@ -156,26 +162,32 @@ public:
Library(std::string outputDirPath, std::string runtimeLibraryPath = "",
bool cleanUp = true)
: outputDirPath(outputDirPath), runtimeLibraryPath(runtimeLibraryPath),
cleanUp(cleanUp) {}
/// Add a compilation result to the library
llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
cleanUp(cleanUp), programInfo() {}
/// Sets the compilation result used by the library
llvm::Expected<std::string>
setCompilationResult(CompilationResult &compilation);
/// Emit the library artifacts with the previously added compilation result
llvm::Error emitArtifacts(bool sharedLib, bool staticLib,
bool clientParameters, bool compilationFeedback,
bool cppHeader);
bool clientParameters, bool compilationFeedback);
/// After a shared library has been emitted, its path is here
std::string sharedLibraryPath;
/// After a static library has been emitted, its path is here
std::string staticLibraryPath;
/// Returns the program info of the library.
Message<concreteprotocol::ProgramInfo> getProgramInfo() const;
/// Returns the path to the output dir.
const std::string &getOutputDirPath() const;
/// Returns the path of the shared library
static std::string getSharedLibraryPath(std::string outputDirPath);
/// Returns the path of the static library
static std::string getStaticLibraryPath(std::string outputDirPath);
/// Returns the path of the client parameters
static std::string getClientParametersPath(std::string outputDirPath);
/// Returns the path of the program info
static std::string getProgramInfoPath(std::string outputDirPath);
/// Returns the path of the compilation feedback
static std::string getCompilationFeedbackPath(std::string outputDirPath);
@@ -194,12 +206,10 @@ public:
llvm::Expected<std::string> emitStatic();
/// Emit a shared library with the previously added compilation result
llvm::Expected<std::string> emitShared();
/// Emit a json ClientParameters corresponding to library content
llvm::Expected<std::string> emitClientParametersJSON();
/// Emit a json ProgramInfo corresponding to library content
llvm::Expected<std::string> emitProgramInfoJSON();
/// Emit a json CompilationFeedback corresponding to library content
llvm::Expected<std::string> emitCompilationFeedbackJSON();
/// Emit a C++ client header file corresponding to the library content
llvm::Expected<std::string> emitCppHeader();
};
/// Specification of the exit stage of the compilation pipeline
@@ -265,8 +275,7 @@ public:
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
generateClientParameters(
compilerOptions.clientParametersFuncName.has_value()),
generateProgramInfo(compilerOptions.mainFuncName.has_value()),
enablePass([](mlir::Pass *pass) { return true; }),
compilationContext(compilationContext) {}
@@ -290,8 +299,7 @@ public:
compile(std::vector<std::string> inputs, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);
bool generateCompilationFeedback = true);
/// Compile and emit artifact to the given outputDirPath from an LLVM source
/// manager.
@@ -299,38 +307,38 @@ public:
compile(llvm::SourceMgr &sm, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);
bool generateCompilationFeedback = true);
llvm::Expected<CompilerEngine::Library>
compile(mlir::ModuleOp module, std::string outputDirPath,
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
bool generateStaticLib = true, bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true);
bool generateCompilationFeedback = true);
void setCompilationOptions(CompilationOptions &options) {
compilerOptions = options;
if (options.v0FHEConstraints.has_value()) {
setFHEConstraints(*options.v0FHEConstraints);
void setCompilationOptions(CompilationOptions options) {
compilerOptions = std::move(options);
if (compilerOptions.v0FHEConstraints.has_value()) {
setFHEConstraints(*compilerOptions.v0FHEConstraints);
}
if (options.clientParametersFuncName.has_value()) {
setGenerateClientParameters(true);
if (compilerOptions.mainFuncName.has_value()) {
setGenerateProgramInfo(true);
}
}
CompilationOptions &getCompilationOptions() { return compilerOptions; }
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
void setMaxEintPrecision(size_t v);
void setMaxMANP(size_t v);
void setGenerateClientParameters(bool v);
void setGenerateProgramInfo(bool v);
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
protected:
std::optional<size_t> overrideMaxEintPrecision;
std::optional<size_t> overrideMaxMANP;
CompilationOptions compilerOptions;
bool generateClientParameters;
bool generateProgramInfo;
std::function<bool(mlir::Pass *)> enablePass;
std::shared_ptr<CompilationContext> compilationContext;
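
Editor's note: putting the renamed pieces together, a hedged sketch of a compile-to-library flow based only on the declarations above. Whether the string inputs of compile() are file paths or inline sources, the return type of that overload, and the CompilationContext::createShared() call (borrowed from elsewhere in this commit) are assumptions.

// Hypothetical driver for the library compilation flow.
// #include "concretelang/Support/CompilerEngine.h"  // path per the guard above
#include "llvm/Support/Error.h"
#include <string>
#include <vector>

llvm::Error compileToLibrary(const std::string &input,
                             const std::string &outputDir) {
  // "main" becomes mainFuncName, which turns on ProgramInfo generation.
  mlir::concretelang::CompilationOptions options("main");
  auto context = mlir::concretelang::CompilationContext::createShared();
  mlir::concretelang::CompilerEngine engine(context);
  engine.setCompilationOptions(options);

  // Compile and emit the artifacts (shared/static lib, program info,
  // compilation feedback) into outputDir.
  auto library = engine.compile(std::vector<std::string>{input}, outputDir);
  if (!library)
    return library.takeError();

  // The ProgramInfo replaces the old ClientParameters as the client-facing
  // description of the compiled program.
  auto programInfo = library->getProgramInfo();
  (void)programInfo;
  return llvm::Error::success();
}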

View File

@@ -7,6 +7,7 @@
#define CONCRETELANG_SUPPORT_ENCODINGS_H_
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <vector>
@@ -21,118 +22,28 @@
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include "concretelang/ClientLib/ClientParameters.h"
#include "capnp/message.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
using concretelang::protocol::Message;
namespace mlir {
namespace concretelang {
namespace encodings {
/// Represents the encoding of a small (unchunked) `FHE::eint` type.
struct EncryptedIntegerScalarEncoding {
uint64_t width;
bool isSigned;
};
bool fromJSON(const llvm::json::Value, EncryptedIntegerScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
EncryptedIntegerScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module);
/// Represents the encoding of a big (chunked) `FHE::eint` type.
struct EncryptedChunkedIntegerScalarEncoding {
uint64_t width;
bool isSigned;
uint64_t chunkSize;
uint64_t chunkWidth;
};
bool fromJSON(const llvm::json::Value, EncryptedChunkedIntegerScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &);
static inline llvm::raw_ostream &
operator<<(llvm::raw_ostream &OS, EncryptedChunkedIntegerScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a `FHE::ebool` type.
struct EncryptedBoolScalarEncoding {};
bool fromJSON(const llvm::json::Value, EncryptedBoolScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
EncryptedBoolScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a builtin integer type.
struct PlaintextScalarEncoding {
uint64_t width;
};
bool fromJSON(const llvm::json::Value, PlaintextScalarEncoding &,
llvm::json::Path);
llvm::json::Value toJSON(const PlaintextScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
PlaintextScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a builtin index type.
struct IndexScalarEncoding {};
bool fromJSON(const llvm::json::Value, IndexScalarEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const IndexScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
IndexScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a scalar value.
using ScalarEncoding = std::variant<
EncryptedIntegerScalarEncoding, EncryptedChunkedIntegerScalarEncoding,
EncryptedBoolScalarEncoding, PlaintextScalarEncoding, IndexScalarEncoding>;
bool fromJSON(const llvm::json::Value, ScalarEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const ScalarEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
ScalarEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of a tensor value.
struct TensorEncoding {
ScalarEncoding scalarEncoding;
};
bool fromJSON(const llvm::json::Value, TensorEncoding &, llvm::json::Path);
llvm::json::Value toJSON(const TensorEncoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
TensorEncoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encoding of either an input or output value of a circuit.
using Encoding = std::variant<TensorEncoding, ScalarEncoding>;
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
llvm::json::Value toJSON(const Encoding &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, Encoding e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
/// Represents the encodings of a circuit.
struct CircuitEncodings {
std::vector<Encoding> inputEncodings;
std::vector<Encoding> outputEncodings;
};
bool fromJSON(const llvm::json::Value, CircuitEncodings &, llvm::json::Path);
llvm::json::Value toJSON(const CircuitEncodings &);
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
CircuitEncodings e) {
return OS << llvm::formatv("{0:2}", toJSON(e));
}
llvm::Expected<CircuitEncodings> getCircuitEncodings(
llvm::StringRef functionName, mlir::ModuleOp module,
std::optional<::concretelang::clientlib::ChunkInfo> maybeChunkInfo);
void setCircuitEncodingModes(
Message<concreteprotocol::CircuitEncodingInfo> &info,
std::optional<
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
maybeChunk,
std::optional<V0FHEContext> maybeFheContext);
} // namespace encodings
} // namespace concretelang
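
Editor's note: a hedged sketch of chaining the two declarations above, deriving the encodings for a function and then fixing the encoding modes. Passing two empty optionals is assumed to leave the default (non-chunked, no-context) modes in place; that behaviour is an assumption, not documented in this header.

// Hypothetical use of the new encodings API.
// #include "concretelang/Support/Encodings.h"  // path assumed
#include "mlir/IR/BuiltinOps.h"
#include "llvm/Support/Error.h"
#include <optional>

llvm::Error deriveEncodings(mlir::ModuleOp module) {
  auto encodings =
      mlir::concretelang::encodings::getCircuitEncodings("main", module);
  if (!encodings)
    return encodings.takeError();
  // No chunked-integer mode requested and no V0 FHE context available.
  mlir::concretelang::encodings::setCircuitEncodingModes(
      *encodings, std::nullopt, std::nullopt);
  return llvm::Error::success();
}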

View File

@@ -1,82 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_JIT_SUPPORT
#define CONCRETELANG_SUPPORT_JIT_SUPPORT
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LambdaSupport.h>
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
/// JitCompilationResult is the result of a JIT compilation: the server JIT
/// lambda and the associated client parameters.
struct JitCompilationResult {
std::shared_ptr<concretelang::JITLambda> lambda;
clientlib::ClientParameters clientParameters;
CompilationFeedback feedback;
};
/// JITSupport is the instantiated LambdaSupport for the Jit Compilation.
class JITSupport
: public LambdaSupport<std::shared_ptr<concretelang::JITLambda>,
JitCompilationResult> {
public:
JITSupport(std::optional<std::string> runtimeLibPath = std::nullopt);
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) override;
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compile(mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options) override;
using LambdaSupport::compile;
llvm::Expected<std::shared_ptr<concretelang::JITLambda>>
loadServerLambda(JitCompilationResult &result) override {
return result.lambda;
}
llvm::Expected<clientlib::ClientParameters>
loadClientParameters(JitCompilationResult &result) override {
return result.clientParameters;
}
llvm::Expected<CompilationFeedback>
loadCompilationFeedback(JitCompilationResult &result) override {
return result.feedback;
}
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) override {
return lambda->call(args, evaluationKeys);
}
private:
template <typename T>
llvm::Expected<std::unique_ptr<JitCompilationResult>>
compileWithEngine(T program, CompilationOptions options,
concretelang::CompilerEngine &engine);
std::optional<std::string> runtimeLibPath;
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,66 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef COMPILER_JIT_H
#define COMPILER_JIT_H
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Support/LogicalResult.h>
#include <concretelang/ClientLib/KeySet.h>
#include <concretelang/ClientLib/PublicArguments.h>
namespace mlir {
namespace concretelang {
using ::concretelang::clientlib::CircuitGate;
using ::concretelang::clientlib::KeySet;
namespace clientlib = ::concretelang::clientlib;
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
class JITLambda {
public:
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};
/// Create a JITLambda that points to the function `name` of the given module.
/// Use runtimeLibPath as a shared library if specified.
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
std::optional<std::string> runtimeLibPath = {});
/// Call the JIT lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
call(clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys);
void setUseDataflow(bool option) { this->useDataflow = option; }
/// invokeRaw executes the JIT lambda with a list of arguments; the last one is
/// used to store the result of the computation.
/// Example:
/// uint64_t arg0 = 1;
/// uint64_t res;
/// llvm::SmallVector<void *> args{&arg0, &res};
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
private:
mlir::LLVM::LLVMFunctionType type;
std::string name;
std::unique_ptr<mlir::ExecutionEngine> engine;
/// Tells whether the DF parallelization was on or off during compilation. This
/// is useful to abort execution if the runtime doesn't support dataflow
/// execution, instead of running into undefined symbol issues
bool useDataflow = false;
};
} // namespace concretelang
} // namespace mlir
#endif // COMPILER_JIT_H

View File

@@ -1,304 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_LAMBDA_ARGUMENT_H
#define CONCRETELANG_SUPPORT_LAMBDA_ARGUMENT_H
#include <cstdint>
#include <limits>
#include <concretelang/Support/Error.h>
#include <llvm/ADT/ArrayRef.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/ExtensibleRTTI.h>
namespace mlir {
namespace concretelang {
/// Abstract base class for lambda arguments
class LambdaArgument
: public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
public:
LambdaArgument(LambdaArgument &) = delete;
template <typename T> bool isa() const { return llvm::isa<T>(*this); }
/// Cast functions on constant instances
template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
template <typename T> const T *dyn_cast() const {
return llvm::dyn_cast<T>(this);
}
/// Cast functions for mutable instances
template <typename T> T &cast() { return llvm::cast<T>(*this); }
template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
static char ID;
protected:
LambdaArgument(){};
};
/// Class for integer arguments. `BackingIntType` is used as the data
/// type to hold the argument's value. The precision is the actual
/// precision of the value, which might be different from the precision
/// of the backing integer type.
template <typename BackingIntType = uint64_t>
class IntLambdaArgument
: public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
LambdaArgument> {
public:
typedef BackingIntType value_type;
IntLambdaArgument(BackingIntType value,
unsigned int precision = 8 * sizeof(BackingIntType))
: precision(precision) {
if (precision < 8 * sizeof(BackingIntType)) {
this->value = value & ((static_cast<BackingIntType>(1) << this->precision) - 1);
} else {
this->value = value;
}
}
unsigned int getPrecision() const { return this->precision; }
BackingIntType getValue() const { return this->value; }
template <typename OtherBackingIntType>
bool operator==(const IntLambdaArgument<OtherBackingIntType> &other) const {
return getValue() == other.getValue();
}
template <typename OtherBackingIntType>
bool operator!=(const IntLambdaArgument<OtherBackingIntType> &other) const {
return !(*this == other);
}
static char ID;
protected:
unsigned int precision;
BackingIntType value;
};
template <typename BackingIntType>
char IntLambdaArgument<BackingIntType>::ID = 0;
/// Class for encrypted integer arguments. `BackingIntType` is used as
/// the data type to hold the argument's plaintext value. The precision
/// is the actual precision of the value, which might be different from
/// the precision of the backing integer type.
template <typename BackingIntType = uint64_t>
class EIntLambdaArgument
: public llvm::RTTIExtends<EIntLambdaArgument<BackingIntType>,
IntLambdaArgument<BackingIntType>> {
public:
static char ID;
};
template <typename BackingIntType>
char EIntLambdaArgument<BackingIntType>::ID = 0;
namespace {
/// Calculates `accu *= factor` or returns an error if the result
/// would overflow
template <typename AccuT, typename ValT>
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
static_assert(std::numeric_limits<AccuT>::is_integer &&
std::numeric_limits<ValT>::is_integer &&
!std::numeric_limits<AccuT>::is_signed &&
!std::numeric_limits<ValT>::is_signed,
"Only unsigned integers are supported");
const AccuT left = std::numeric_limits<AccuT>::max() / accu;
if (left > factor) {
accu *= factor;
return llvm::Error::success();
}
return StreamStringError("Multiplying value ")
<< accu << " with " << factor << " would cause an overflow";
}
} // namespace
/// Class for Tensor arguments. This can either be plaintext tensors
/// (for `ScalarArgumentT = IntLambdaArgument<T>`) or tensors
/// representing encrypted integers (for `ScalarArgumentT =
/// EIntLambdaArgument<T>`).
template <typename ScalarArgumentT>
class TensorLambdaArgument
: public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
LambdaArgument> {
public:
typedef ScalarArgumentT scalar_type;
/// Construct tensor argument from the one-dimensional array `value`,
/// but interpreting the array's values as a linearized
/// multi-dimensional tensor with the sizes of the dimensions
/// specified in `dimensions`.
TensorLambdaArgument(
llvm::ArrayRef<typename ScalarArgumentT::value_type> value,
llvm::ArrayRef<int64_t> dimensions)
: dimensions(dimensions.vec()) {
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
}
/// Construct tensor argument by moving the values from the
/// one-dimensional vector `value`, but interpreting the array's
/// values as a linearized multi-dimensional tensor with the sizes
/// of the dimensions specified in `dimensions`.
TensorLambdaArgument(
std::vector<typename ScalarArgumentT::value_type> &&value,
llvm::ArrayRef<int64_t> dimensions)
: dimensions(dimensions.vec()), value(std::move(value)) {}
/// Construct a one-dimensional tensor argument from the
/// array `value`.
TensorLambdaArgument(
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
template <std::size_t size1, std::size_t size2>
TensorLambdaArgument(
typename ScalarArgumentT::value_type (&a)[size1][size2]) {
dimensions = {size1, size2};
auto value = llvm::MutableArrayRef<typename ScalarArgumentT::value_type>(
(typename ScalarArgumentT::value_type *)a, size1 * size2);
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
}
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
/// Returns the total number of elements in the tensor. If the number
/// of elements cannot be represented as a `size_t`, the method
/// returns an error.
llvm::Expected<size_t> getNumElements() const {
size_t accu = 1;
for (unsigned int dimSize : dimensions)
if (llvm::Error err = safeUnsignedMul(accu, dimSize))
return std::move(err);
return accu;
}
/// Returns a bare pointer to the linearized values of the tensor
/// (constant version).
const typename ScalarArgumentT::value_type *getValue() const {
return this->value.data();
}
/// Returns a bare pointer to the linearized values of the tensor (mutable
/// version).
typename ScalarArgumentT::value_type *getValue() {
return this->value.data();
}
template <typename OtherScalarArgumentT>
bool
operator==(const TensorLambdaArgument<OtherScalarArgumentT> &other) const {
if (getDimensions() != other.getDimensions())
return false;
for (auto pair : llvm::zip(value, other.value)) {
if (std::get<0>(pair) != std::get<1>(pair))
return false;
}
return true;
}
template <typename OtherScalarArgumentT>
bool
operator!=(const TensorLambdaArgument<OtherScalarArgumentT> &other) const {
return !(*this == other);
}
static char ID;
protected:
std::vector<typename ScalarArgumentT::value_type> value;
std::vector<int64_t> dimensions;
};
template <typename ScalarArgumentT>
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
namespace {
template <typename T> struct NameOfFundamentalType {
static const char *get();
};
template <> struct NameOfFundamentalType<uint8_t> {
static const char *get() { return "uint8_t"; }
};
template <> struct NameOfFundamentalType<int8_t> {
static const char *get() { return "int8_t"; }
};
template <> struct NameOfFundamentalType<uint16_t> {
static const char *get() { return "uint16_t"; }
};
template <> struct NameOfFundamentalType<int16_t> {
static const char *get() { return "int16_t"; }
};
template <> struct NameOfFundamentalType<uint32_t> {
static const char *get() { return "uint32_t"; }
};
template <> struct NameOfFundamentalType<int32_t> {
static const char *get() { return "int32_t"; }
};
template <> struct NameOfFundamentalType<uint64_t> {
static const char *get() { return "uint64_t"; }
};
template <> struct NameOfFundamentalType<int64_t> {
static const char *get() { return "int64_t"; }
};
template <typename... Ts> struct LambdaArgumentTypeName;
template <> struct LambdaArgumentTypeName<> {
static const char *get(const mlir::concretelang::LambdaArgument &arg) {
assert(false && "No name implemented for this lambda argument type");
return nullptr;
}
};
template <typename T, typename... Ts> struct LambdaArgumentTypeName<T, Ts...> {
static const std::string get(const mlir::concretelang::LambdaArgument &arg) {
if (arg.dyn_cast<const IntLambdaArgument<T>>()) {
return NameOfFundamentalType<T>::get();
} else if (arg.dyn_cast<const EIntLambdaArgument<T>>()) {
return std::string("encrypted ") + NameOfFundamentalType<T>::get();
} else if (arg.dyn_cast<
const TensorLambdaArgument<IntLambdaArgument<T>>>()) {
return std::string("tensor<") + NameOfFundamentalType<T>::get() + ">";
} else if (arg.dyn_cast<
const TensorLambdaArgument<EIntLambdaArgument<T>>>()) {
return std::string("tensor<encrypted ") +
NameOfFundamentalType<T>::get() + ">";
}
return LambdaArgumentTypeName<Ts...>::get(arg);
}
};
} // namespace
static inline const std::string
getLambdaArgumentTypeAsString(const LambdaArgument &arg) {
return LambdaArgumentTypeName<int8_t, uint8_t, int16_t, uint16_t, int32_t,
uint32_t, int64_t, uint64_t>::get(arg);
}
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,541 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT
#define CONCRETELANG_SUPPORT_LAMBDASUPPORT
#include "boost/outcome.h"
#include "concretelang/Support/LambdaArgument.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
/// Helper function for implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result);
template <typename T>
inline llvm::Expected<T> typedScalarResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
auto clearResult = result.asClearTextScalar<T>(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedResult cannot get clear text scalar")
<< clearResult.error().mesg;
}
return clearResult.value();
}
/// Specializations of `typedResult()` for scalar results, forwarding
/// scalar value to caller.
template <>
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<uint64_t>(keySet, result);
}
template <>
inline llvm::Expected<int64_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<int64_t>(keySet, result);
}
template <>
inline llvm::Expected<uint32_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<uint32_t>(keySet, result);
}
template <>
inline llvm::Expected<int32_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<int32_t>(keySet, result);
}
template <>
inline llvm::Expected<uint16_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<uint16_t>(keySet, result);
}
template <>
inline llvm::Expected<int16_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<int16_t>(keySet, result);
}
template <>
inline llvm::Expected<uint8_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<uint8_t>(keySet, result);
}
template <>
inline llvm::Expected<int8_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
return typedScalarResult<int8_t>(keySet, result);
}
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector<T>(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedVectorResult cannot get clear text vector")
<< clearResult.error().mesg;
}
return std::move(clearResult.value());
}
/// Specializations of `typedResult()` for vector results, initializing
/// an `std::vector` of the right size with the results and forwarding
/// it to the caller with move semantics.
/// These cannot be factored out into a single template
/// `template <typename T> llvm::Expected<std::vector<T>>
/// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result);`
/// due to ambiguity with the scalar template.
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint8_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int8_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<int8_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint16_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int16_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<int16_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint32_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int32_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<int32_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint64_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<int64_t>(keySet, result);
}
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
llvm::Expected<std::vector<T>> tensorOrError =
typedResult<std::vector<T>>(keySet, result);
if (auto err = tensorOrError.takeError())
return std::move(err);
auto tensorDim = result.asClearTextShape(0);
if (tensorDim.has_error())
return StreamStringError(tensorDim.error().mesg);
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, tensorDim.value());
}
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildScalarLambdaResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
llvm::Expected<T> scalarOrError = typedResult<T>(keySet, result);
if (auto err = scalarOrError.takeError())
return std::move(err);
return std::make_unique<IntLambdaArgument<T>>(*scalarOrError);
}
/// Specialization of `typedResult()` for a single result wrapped into
/// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto gate = keySet.outputGate(0);
auto width = gate.shape.width;
bool sign = gate.shape.sign;
if (width > 64)
return StreamStringError("Cannot handle values with more than 64 bits");
// By convention, decrypted integers are always 64 bits wide
if (gate.isEncrypted())
width = 64;
if (gate.shape.dimensions.empty()) {
// scalar case
if (width > 32) {
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
: buildScalarLambdaResult<uint64_t>(keySet, result);
} else if (width > 16) {
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
: buildScalarLambdaResult<uint32_t>(keySet, result);
} else if (width > 8) {
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
: buildScalarLambdaResult<uint16_t>(keySet, result);
} else if (width <= 8) {
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
: buildScalarLambdaResult<uint8_t>(keySet, result);
}
} else if (gate.chunkInfo.has_value()) {
// chunked scalar case
assert(gate.shape.dimensions.size() == 1);
width = gate.shape.size * gate.chunkInfo->width;
if (width > 32) {
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
: buildScalarLambdaResult<uint64_t>(keySet, result);
} else if (width > 16) {
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
: buildScalarLambdaResult<uint32_t>(keySet, result);
} else if (width > 8) {
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
: buildScalarLambdaResult<uint16_t>(keySet, result);
} else if (width <= 8) {
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
: buildScalarLambdaResult<uint8_t>(keySet, result);
}
} else {
// tensor case
if (width > 32) {
return (sign) ? buildTensorLambdaResult<int64_t>(keySet, result)
: buildTensorLambdaResult<uint64_t>(keySet, result);
} else if (width > 16) {
return (sign) ? buildTensorLambdaResult<int32_t>(keySet, result)
: buildTensorLambdaResult<uint32_t>(keySet, result);
} else if (width > 8) {
return (sign) ? buildTensorLambdaResult<int16_t>(keySet, result)
: buildTensorLambdaResult<uint16_t>(keySet, result);
} else if (width <= 8) {
return (sign) ? buildTensorLambdaResult<int8_t>(keySet, result)
: buildTensorLambdaResult<uint8_t>(keySet, result);
}
}
assert(false && "Cannot happen");
}
} // namespace
/// Adaptor class that push arguments specified as instances of
/// `LambdaArgument` to `clientlib::EncryptedArguments`.
class LambdaArgumentAdaptor {
public:
/// Checks if the argument `arg` is a plaintext / encrypted integer
/// argument or a plaintext / encrypted tensor argument with a
/// backing integer type `IntT` and push the argument to `encryptedArgs`.
///
/// Returns `true` if `arg` has one of the types above and its value
/// was successfully added to `encryptedArgs`, `false` if none of the types
/// matches or an error if a type matched, but adding the argument to
/// `encryptedArgs` failed.
template <typename IntT>
static inline llvm::Expected<bool>
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
} else if (auto tla = arg.dyn_cast<
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
auto res =
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
}
return false;
}
/// Recursive case for `tryAddArg<IntT>(...)`
template <typename IntT, typename NextIntT, typename... IntTs>
static inline llvm::Expected<bool>
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
llvm::Expected<bool> successOrError =
tryAddArg<IntT>(encryptedArgs, arg, keySet);
if (!successOrError)
return successOrError.takeError();
if (successOrError.get() == false)
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
else
return true;
}
/// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
/// error if either the argument type is unsupported, or if the argument type
/// is supported but adding it to `encryptedArgs` failed.
static inline llvm::Error
addArgument(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
// Try the supported integer types; size_t needs explicit
// treatment, since it may alias none of the fixed size integer
// types
llvm::Expected<bool> successOrError =
LambdaArgumentAdaptor::tryAddArg<int64_t, int32_t, int16_t, int8_t,
uint64_t, uint32_t, uint16_t, uint8_t,
size_t>(encryptedArgs, arg, keySet);
if (!successOrError)
return successOrError.takeError();
if (successOrError.get() == false)
return StreamStringError("Unknown argument type");
else
return llvm::Error::success();
}
/// Encrypts and build public arguments from lambda arguments
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
exportArguments(llvm::ArrayRef<const LambdaArgument *> args,
clientlib::ClientParameters clientParameters,
clientlib::KeySet &keySet) {
auto encryptedArgs = clientlib::EncryptedArguments::empty();
for (auto arg : args) {
if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg,
keySet)) {
return std::move(err);
}
}
auto check = encryptedArgs->checkAllArgs(keySet);
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
auto publicArguments =
encryptedArgs->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}
return std::move(publicArguments.value());
}
};
template <typename Lambda, typename CompilationResult> class LambdaSupport {
public:
typedef Lambda lambda;
typedef CompilationResult compilationResult;
virtual ~LambdaSupport() {}
/// Compile the mlir program and produces a compilation result if succeed.
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
llvm::SourceMgr &program,
CompilationOptions options = CompilationOptions("main")) = 0;
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options = CompilationOptions("main")) = 0;
llvm::Expected<std::unique_ptr<CompilationResult>>
compile(llvm::StringRef program,
CompilationOptions options = CompilationOptions("main")) {
return compile(llvm::MemoryBuffer::getMemBuffer(program), options);
}
llvm::Expected<std::unique_ptr<CompilationResult>>
compile(std::unique_ptr<llvm::MemoryBuffer> program,
CompilationOptions options = CompilationOptions("main")) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
return compile(sm, options);
}
/// Load the server lambda from the compilation result.
llvm::Expected<Lambda> virtual loadServerLambda(
CompilationResult &result) = 0;
/// Load the client parameters from the compilation result.
llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
CompilationResult &result) = 0;
/// Load the compilation feedback from the compilation result.
llvm::Expected<CompilationFeedback> virtual loadCompilationFeedback(
CompilationResult &result) = 0;
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
Lambda lambda, clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) = 0;
/// Build the client KeySet from the client parameters.
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
keySet(clientlib::ClientParameters clientParameters,
std::optional<clientlib::KeySetCache> cache, uint64_t seed_msb = 0,
uint64_t seed_lsb = 0) {
std::shared_ptr<clientlib::KeySetCache> cachePtr;
if (cache.has_value()) {
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.value());
}
auto keySet = clientlib::KeySetCache::generate(cachePtr, clientParameters,
seed_msb, seed_lsb);
if (keySet.has_error()) {
return StreamStringError(keySet.error().mesg);
}
return std::move(keySet.value());
}
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
exportArguments(clientlib::ClientParameters clientParameters,
clientlib::KeySet &keySet,
llvm::ArrayRef<const LambdaArgument *> args) {
return LambdaArgumentAdaptor::exportArguments(args, clientParameters,
keySet);
}
template <typename ResT>
static llvm::Expected<ResT> call(Lambda lambda,
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
// Call the lambda
auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
lambda, publicArguments, evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
// Decrypt the result
return typedResult<ResT>(keySet, **publicResult);
}
};
template <class LambdaSupport> class ClientServer {
public:
static llvm::Expected<ClientServer>
create(llvm::StringRef program,
CompilationOptions options = CompilationOptions("main"),
std::optional<clientlib::KeySetCache> cache = {},
LambdaSupport support = LambdaSupport()) {
auto compilationResult = support.compile(program, options);
if (auto err = compilationResult.takeError()) {
return std::move(err);
}
auto lambda = support.loadServerLambda(**compilationResult);
if (auto err = lambda.takeError()) {
return std::move(err);
}
auto clientParameters = support.loadClientParameters(**compilationResult);
if (auto err = clientParameters.takeError()) {
return std::move(err);
}
auto keySet = support.keySet(*clientParameters, cache);
if (auto err = keySet.takeError()) {
return std::move(err);
}
auto f = ClientServer();
f.lambda = *lambda;
f.compilationResult = std::move(*compilationResult);
f.keySet = std::move(*keySet);
f.clientParameters = *clientParameters;
f.support = support;
return std::move(f);
}
template <typename ResT = uint64_t>
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
args, clientParameters, *this->keySet);
if (auto err = publicArguments.takeError()) {
return std::move(err);
}
auto evaluationKeys = this->keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, **publicArguments, evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
return typedResult<ResT>(*keySet, **publicResult);
}
template <typename T, typename ResT = uint64_t>
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
auto encryptedArgs = clientlib::EncryptedArguments::create(
/*simulation*/ false, *keySet, args);
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
}
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
return typedResult<ResT>(*keySet, **publicResult);
}
template <typename ResT = uint64_t, typename... Args>
llvm::Expected<ResT> operator()(const Args... args) {
auto encryptedArgs = clientlib::EncryptedArguments::create(
/*simulation*/ false, *keySet, args...);
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
auto publicArguments =
encryptedArgs.value()->exportPublicArguments(clientParameters);
if (publicArguments.has_error()) {
return StreamStringError(publicArguments.error().mesg);
}
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult =
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
return typedResult<ResT>(*keySet, **publicResult);
}
private:
typename LambdaSupport::lambda lambda;
std::unique_ptr<typename LambdaSupport::compilationResult> compilationResult;
std::unique_ptr<clientlib::KeySet> keySet;
clientlib::ClientParameters clientParameters;
LambdaSupport support;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -1,193 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_LIBRARY_SUPPORT
#define CONCRETELANG_SUPPORT_LIBRARY_SUPPORT
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <concretelang/ServerLib/ServerLambda.h>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LambdaSupport.h>
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
namespace serverlib = ::concretelang::serverlib;
/// LibraryCompilationResult is the result of a compilation to a library.
struct LibraryCompilationResult {
/// The output directory path where the compilation artifacts have been
/// generated.
std::string outputDirPath;
std::string funcName;
};
class LibrarySupport
: public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
public:
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
bool generateSharedLib = true, bool generateStaticLib = true,
bool generateClientParameters = true,
bool generateCompilationFeedback = true,
bool generateCppHeader = true)
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
generateSharedLib(generateSharedLib),
generateStaticLib(generateStaticLib),
generateClientParameters(generateClientParameters),
generateCompilationFeedback(generateCompilationFeedback),
generateCppHeader(generateCppHeader) {}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(llvm::SourceMgr &program, CompilationOptions options) override {
// Setup the compiler engine
auto context = CompilationContext::createShared();
concretelang::CompilerEngine engine(context);
engine.setCompilationOptions(options);
return compileWithEngine<llvm::SourceMgr &>(program, options, engine);
}
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compile(mlir::ModuleOp program,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
CompilationOptions options) override {
// Setup the compiler engine
concretelang::CompilerEngine engine(cctx);
engine.setCompilationOptions(options);
return compileWithEngine<mlir::ModuleOp>(program, options, engine);
}
using LambdaSupport::compile;
/// Load the server lambda from the compilation result.
llvm::Expected<serverlib::ServerLambda>
loadServerLambda(LibraryCompilationResult &result) override {
auto lambda =
serverlib::ServerLambda::load(result.funcName, result.outputDirPath);
if (lambda.has_error()) {
return StreamStringError(lambda.error().mesg);
}
return lambda.value();
}
/// Load the client parameters from the compilation result.
llvm::Expected<clientlib::ClientParameters>
loadClientParameters(LibraryCompilationResult &result) override {
auto path =
CompilerEngine::Library::getClientParametersPath(result.outputDirPath);
auto params = ClientParameters::load(path);
if (params.has_error()) {
return StreamStringError(params.error().mesg);
}
auto param = llvm::find_if(params.value(), [&](ClientParameters param) {
return param.functionName == result.funcName;
});
if (param == params.value().end()) {
return StreamStringError("ClientLambda: cannot find function(")
<< result.funcName << ") in client parameters path(" << path
<< ")";
}
return *param;
}
std::string getFuncName() {
auto path = CompilerEngine::Library::getClientParametersPath(outputPath);
auto params = ClientParameters::load(path);
if (params.has_error() || params.value().empty()) {
return "";
}
return params.value().front().functionName;
}
  /// Load the compilation result if the circuit has already been compiled
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
loadCompilationResult() {
auto funcName = getFuncName();
if (funcName.empty()) {
return StreamStringError("couldn't find function name");
}
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = funcName;
return std::move(result);
}
llvm::Expected<CompilationFeedback>
loadCompilationFeedback(LibraryCompilationResult &result) override {
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
result.outputDirPath);
auto feedback = CompilationFeedback::load(path);
if (feedback.has_error()) {
return StreamStringError(feedback.error().mesg);
}
return feedback.value();
}
/// Call the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args,
clientlib::EvaluationKeys &evaluationKeys) override {
return lambda.call(args, evaluationKeys);
}
  /// Simulate a call to the lambda with the public arguments.
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
simulate(serverlib::ServerLambda lambda, clientlib::PublicArguments &args) {
return lambda.call(args, {}, /*simulation*/ true);
}
/// Get path to shared library
std::string getSharedLibPath() {
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
}
/// Get path to client parameters file
std::string getClientParametersPath() {
return CompilerEngine::Library::getClientParametersPath(outputPath);
}
private:
std::string outputPath;
std::string runtimeLibraryPath;
/// Flags to select generated artifacts
bool generateSharedLib;
bool generateStaticLib;
bool generateClientParameters;
bool generateCompilationFeedback;
bool generateCppHeader;
template <typename T>
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
compileWithEngine(T program, CompilationOptions options,
concretelang::CompilerEngine &engine) {
// Compile to a library
auto library = engine.compile(
program, outputPath, runtimeLibraryPath, generateSharedLib,
generateStaticLib, generateClientParameters,
generateCompilationFeedback, generateCppHeader);
if (auto err = library.takeError()) {
return std::move(err);
}
if (!options.clientParametersFuncName.has_value()) {
return StreamStringError("Need to have a funcname to compile library");
}
auto result = std::make_unique<LibraryCompilationResult>();
result->outputDirPath = outputPath;
result->funcName = *options.clientParametersFuncName;
return std::move(result);
}
};
} // namespace concretelang
} // namespace mlir
#endif
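Before its removal in this commit, the compile-then-call flow through this header looked roughly like the sketch below; the output directory, helper name and error handling are illustrative, not taken from the codebase:

// Sketch of the legacy LibrarySupport flow replaced by this commit.
namespace clientlib = ::concretelang::clientlib;

llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
compileAndCall(llvm::SourceMgr &program,
               mlir::concretelang::CompilationOptions options,
               clientlib::PublicArguments &args,
               clientlib::EvaluationKeys &evaluationKeys) {
  // "./out" is an illustrative output directory.
  mlir::concretelang::LibrarySupport support("./out");
  auto compiled = support.compile(program, options);
  if (auto err = compiled.takeError())
    return std::move(err);
  auto lambda = support.loadServerLambda(**compiled);
  if (auto err = lambda.takeError())
    return std::move(err);
  return support.serverCall(*lambda, args, evaluationKeys);
}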

View File

@@ -6,12 +6,11 @@
#ifndef CONCRETELANG_SUPPORT_PIPELINE_H_
#define CONCRETELANG_SUPPORT_PIPELINE_H_
#include <llvm/IR/Module.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Transforms/Passes.h>
#include <concretelang/Support/V0Parameters.h>
#include "concretelang/Support/V0Parameters.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/IR/Module.h"
namespace mlir {
namespace concretelang {
@@ -109,8 +108,7 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
mlir::LogicalResult
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool simulation);
std::function<bool(mlir::Pass *)> enablePass);
mlir::LogicalResult
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,

View File

@@ -0,0 +1,30 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_PROGRAMINFOGENERATION_H_
#define CONCRETELANG_SUPPORT_PROGRAMINFOGENERATION_H_
#include "concrete-protocol.capnp.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Support/Encodings.h"
#include "concretelang/Support/V0Parameters.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinOps.h"
#include <memory>
using concretelang::protocol::Message;
namespace mlir {
namespace concretelang {
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
createProgramInfoFromTfheDialect(
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
Message<concreteprotocol::CircuitEncodingInfo> &encodings);
} // namespace concretelang
} // namespace mlir
#endif
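As a usage sketch, a caller in the compilation pipeline is expected to invoke this entry point roughly as follows; the wrapper name `emitProgramInfo`, the function name "main" and the 128-bit security level are illustrative choices, not fixed by this header:

// Sketch: produce the ProgramInfo artifact for a module already lowered to
// the TFHE dialect. `module` and `encodings` come from earlier pipeline steps.
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
emitProgramInfo(mlir::ModuleOp module,
                Message<concreteprotocol::CircuitEncodingInfo> &encodings) {
  return mlir::concretelang::createProgramInfoFromTfheDialect(
      module, /*functionName=*/"main", /*bitsOfSecurity=*/128, encodings);
}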

View File

@@ -6,14 +6,10 @@
#ifndef CONCRETELANG_SUPPORT_UTILS_H_
#define CONCRETELANG_SUPPORT_UTILS_H_
#include <concretelang/ClientLib/ClientLambda.h>
#include <concretelang/ClientLib/KeySet.h>
#include <concretelang/ClientLib/PublicArguments.h>
#include <concretelang/ClientLib/Serializers.h>
#include <concretelang/Runtime/context.h>
#include <concretelang/ServerLib/ServerLambda.h>
#include <concretelang/Support/Error.h>
#include <llvm/ADT/SmallVector.h>
#include "concrete-protocol.capnp.h"
#include "concretelang/Runtime/context.h"
#include "concretelang/Support/Error.h"
#include "llvm/ADT/SmallVector.h"
namespace concretelang {
@@ -28,161 +24,6 @@ std::string makePackedFunctionName(llvm::StringRef name);
// and two array of rank size for sizes and strides.
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank);
template <typename Lambda>
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
std::vector<void *> preparedInputArgs,
std::optional<clientlib::EvaluationKeys> evaluationKeys,
bool simulation = false) {
// invokeRaw needs pointers to the arguments and a pointer to the result
// as its last argument.
// Prepare the outputs vector to store the output value of the lambda.
auto numOutputs = 0;
for (auto &output : clientParameters.outputs) {
auto shape = clientParameters.bufferShape(output, simulation);
if (shape.size() == 0) {
// scalar gate
numOutputs += 1;
} else {
// buffer gate
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
}
}
std::vector<uint64_t> outputs(numOutputs);
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointers to
// the inputs and outputs.
std::vector<void *> rawArgs(preparedInputArgs.size() +
(simulation ? 0 : 1) /*runtime context*/ +
1 /* outputs */
);
size_t i = 0;
// Pointers on inputs
for (auto &arg : preparedInputArgs) {
rawArgs[i++] = &arg;
}
// Some calls require the runtime context while others don't (in simulation)
std::unique_ptr<mlir::concretelang::RuntimeContext> runtimeContext;
mlir::concretelang::RuntimeContext *rtCtxPtr;
if (!simulation) {
assert(evaluationKeys.has_value() &&
"evaluation keys are required if not in simulation");
runtimeContext = std::make_unique<mlir::concretelang::RuntimeContext>(
evaluationKeys.value());
// Pointer to the runtime context; rawArgs takes a pointer to the actual
// value that is passed to the compiled function.
rtCtxPtr = runtimeContext.get();
rawArgs[i++] = &rtCtxPtr;
}
// Outputs
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
// Invoke
if (auto err = lambda->invokeRaw(rawArgs)) {
return std::move(err);
}
// Store the result to the PublicResult
std::vector<clientlib::SharedScalarOrTensorData> buffers;
{
size_t outputOffset = 0;
for (auto &output : clientParameters.outputs) {
auto shape = clientParameters.bufferShape(output, simulation);
if (shape.size() == 0) {
if (simulation) {
// value is encrypted (simulated)
auto value = concretelang::clientlib::ScalarOrTensorData(
concretelang::clientlib::ScalarData(
outputs[outputOffset++],
clientlib::EncryptedScalarElementType,
clientlib::EncryptedScalarElementWidth));
auto sharedValue =
clientlib::SharedScalarOrTensorData(std::move(value));
buffers.push_back(sharedValue);
} else {
// plain scalar
auto value = concretelang::clientlib::ScalarOrTensorData(
concretelang::clientlib::ScalarData(outputs[outputOffset++],
output.shape.sign,
output.shape.width));
auto sharedValue =
clientlib::SharedScalarOrTensorData(std::move(value));
buffers.push_back(sharedValue);
}
} else {
// buffer gate
auto rank = shape.size();
auto allocated = (uint64_t *)outputs[outputOffset++];
auto aligned = (uint64_t *)outputs[outputOffset++];
auto offset = (size_t)outputs[outputOffset++];
size_t *sizes = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t *strides = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t elementWidth = (output.isEncrypted())
? clientlib::EncryptedScalarElementWidth
: output.shape.width;
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
auto value = concretelang::clientlib::ScalarOrTensorData(
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
aligned, offset, sizes, strides));
auto sharedValue =
clientlib::SharedScalarOrTensorData(std::move(value));
buffers.push_back(sharedValue);
}
}
}
return clientlib::PublicResult::fromBuffers(clientParameters,
std::move(buffers));
}
template <typename Lambda>
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
std::optional<clientlib::EvaluationKeys> evaluationKeys,
bool simulation = false) {
// Prepare arguments with the right calling convention
std::vector<void *> preparedArgs;
for (auto &sharedArg : arguments.getArguments()) {
clientlib::ScalarOrTensorData &arg = sharedArg.get();
if (arg.isScalar()) {
auto scalar = arg.getScalar().getValueAsU64();
preparedArgs.push_back((void *)scalar);
} else {
clientlib::TensorData &td = arg.getTensor();
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(td.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = td.getNumElements();
for (size_t size : td.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
}
}
return invokeRawOnLambda(lambda, arguments.getClientParameters(),
preparedArgs, evaluationKeys, simulation);
}
template <typename V, unsigned int N>
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const llvm::SmallVector<V, N> vect) {

View File

@@ -0,0 +1,186 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_TESTCIRCUIT_H
#define CONCRETELANG_TESTLIB_TESTCIRCUIT_H
#include "boost/outcome.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/ClientLib/ClientLib.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Values.h"
#include "concretelang/ServerLib/ServerLib.h"
#include "concretelang/Support/CompilerEngine.h"
#include "tests_tools/keySetCache.h"
#include "llvm/Support/Path.h"
#include <filesystem>
#include <memory>
#include <ostream>
#include <string>
#include <thread>
using concretelang::clientlib::ClientCircuit;
using concretelang::clientlib::ClientProgram;
using concretelang::csprng::ConcreteCSPRNG;
using concretelang::error::Result;
using concretelang::keysets::Keyset;
using concretelang::serverlib::ServerCircuit;
using concretelang::serverlib::ServerProgram;
using concretelang::values::TransportValue;
using concretelang::values::Value;
namespace concretelang {
namespace testlib {
class TestCircuit {
public:
static Result<TestCircuit>
create(Keyset keyset, Message<concreteprotocol::ProgramInfo> programInfo,
std::string sharedLibPath, uint64_t seedMsb, uint64_t seedLsb,
bool useSimulation = false) {
OUTCOME_TRY(auto serverProgram,
ServerProgram::load(programInfo, sharedLibPath, useSimulation));
OUTCOME_TRY(auto serverCircuit,
serverProgram.getServerCircuit(
programInfo.asReader().getCircuits()[0].getName()));
__uint128_t seed = seedMsb;
seed <<= 64;
seed += seedLsb;
std::shared_ptr<CSPRNG> csprng = std::make_shared<ConcreteCSPRNG>(seed);
OUTCOME_TRY(auto clientProgram,
ClientProgram::create(programInfo, keyset.client, csprng,
useSimulation));
OUTCOME_TRY(auto clientCircuit,
clientProgram.getClientCircuit(
programInfo.asReader().getCircuits()[0].getName()));
auto artifactFolder = std::filesystem::path(sharedLibPath).parent_path();
return TestCircuit(clientCircuit, serverCircuit, useSimulation,
artifactFolder, keyset);
}
TestCircuit(ClientCircuit clientCircuit, ServerCircuit serverCircuit,
bool useSimulation, Keyset keyset)
: clientCircuit(clientCircuit), serverCircuit(serverCircuit),
useSimulation(useSimulation), keyset(keyset) {}
Result<std::vector<Value>> call(std::vector<Value> inputs) {
auto preparedArgs = std::vector<TransportValue>();
for (size_t i = 0; i < inputs.size(); i++) {
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
preparedArgs.push_back(preparedInput);
}
std::vector<TransportValue> returns;
if (useSimulation) {
OUTCOME_TRY(returns, serverCircuit.simulate(preparedArgs));
} else {
OUTCOME_TRY(returns, serverCircuit.call(keyset.server, preparedArgs));
}
std::vector<Value> processedOutputs(returns.size());
for (size_t i = 0; i < processedOutputs.size(); i++) {
OUTCOME_TRY(processedOutputs[i],
clientCircuit.processOutput(returns[i], i));
}
return processedOutputs;
}
std::string getArtifactFolder() { return artifactFolder; }
private:
TestCircuit(ClientCircuit clientCircuit, ServerCircuit serverCircuit,
bool useSimulation, std::string artifactFolder, Keyset keyset)
: clientCircuit(clientCircuit), serverCircuit(serverCircuit),
useSimulation(useSimulation), artifactFolder(artifactFolder),
keyset(keyset){};
ClientCircuit clientCircuit;
ServerCircuit serverCircuit;
bool useSimulation;
std::string artifactFolder;
Keyset keyset;
};
TestCircuit load(mlir::concretelang::CompilerEngine::Library compiled) {
auto keyset =
getTestKeySetCachePtr()
->getKeyset(compiled.getProgramInfo().asReader().getKeyset(), 0, 0)
.value();
return TestCircuit::create(
keyset, compiled.getProgramInfo().asReader(),
compiled.getSharedLibraryPath(compiled.getOutputDirPath()), 0, 0,
false)
.value();
}
const std::string FUNCNAME = "main";
std::string getSystemTempFolderPath() {
llvm::SmallString<0> tempPath;
llvm::sys::path::system_temp_directory(true, tempPath);
return std::string(tempPath);
}
std::string createTempFolderIn(const std::string &rootFolder) {
std::srand(std::time(nullptr));
auto new_path = [=]() {
llvm::SmallString<0> outputPath;
llvm::sys::path::append(outputPath, rootFolder);
std::string uid = std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
uid.append("-");
uid.append(std::to_string(std::rand()));
llvm::sys::path::append(outputPath, uid);
return std::string(outputPath);
};
// macOS sometimes fails to create new directories, so we retry a few
// times.
for (size_t i = 0; i < 5; i++) {
auto pathString = new_path();
auto ec = std::error_code();
if (!std::filesystem::create_directory(pathString, ec)) {
std::cout << "Failed to create directory ";
std::cout << pathString;
std::cout << " Reason: ";
std::cout << ec.message();
std::cout << " Retrying....\n";
std::cout.flush();
} else {
std::cout << "Using artifact folder ";
std::cout << pathString;
std::cout << "\n";
std::cout.flush();
return pathString;
}
}
std::cout << "Failed to create temp directory 5 times. Aborting...\n";
std::cout.flush();
assert(false);
}
void deleteFolder(const std::string &folder) {
if (!folder.empty()) {
auto ec = std::error_code();
if (!std::filesystem::remove_all(folder, ec)) {
std::cout << "Failed to delete directory ";
std::cout << folder;
std::cout << " Reason: ";
std::cout << ec.message();
std::cout.flush();
assert(false);
}
}
}
std::vector<uint8_t> values_3bits() { return {0, 1, 2, 5, 7}; }
std::vector<uint8_t> values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; }
std::vector<uint8_t> values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; }
} // namespace testlib
} // namespace concretelang
#endif
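A minimal end-to-end sketch of the helpers above; the compiled library, the `runOnce` helper and the input values are assumed to be produced by an earlier compilation step and are not part of this file:

// Sketch only: encrypt inputs, run (or simulate) the circuit, and decrypt
// the outputs through the prepareInput/call/processOutput pipeline.
concretelang::error::Result<std::vector<concretelang::values::Value>>
runOnce(mlir::concretelang::CompilerEngine::Library compiled,
        std::vector<concretelang::values::Value> inputs) {
  concretelang::testlib::TestCircuit circuit =
      concretelang::testlib::load(compiled);
  return circuit.call(inputs);
}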

View File

@@ -1,117 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
#define CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace concretelang {
namespace testlib {
using concretelang::clientlib::ClientLambda;
using concretelang::clientlib::ClientParameters;
using concretelang::clientlib::KeySet;
using concretelang::clientlib::KeySetCache;
using concretelang::error::StringError;
using concretelang::serverlib::ServerLambda;
inline void freeStringMemory(std::string &s) {
std::string empty;
s.swap(empty);
}
template <typename Result, typename... Args>
class TestTypedLambda
: public concretelang::clientlib::TypedClientLambda<Result, Args...> {
template <typename Result_, typename... Args_>
using TypedClientLambda =
concretelang::clientlib::TypedClientLambda<Result_, Args_...>;
public:
static outcome::checked<TestTypedLambda, StringError>
load(std::string funcName, std::string outputLib, uint64_t seed_msb = 0,
uint64_t seed_lsb = 0,
std::shared_ptr<KeySetCache> unsecure_cache = nullptr) {
std::string jsonPath =
mlir::concretelang::CompilerEngine::Library::getClientParametersPath(
outputLib);
OUTCOME_TRY(auto cLambda, ClientLambda::load(funcName, jsonPath));
OUTCOME_TRY(auto sLambda, ServerLambda::load(funcName, outputLib));
OUTCOME_TRY(std::shared_ptr<KeySet> keySet,
KeySetCache::generate(unsecure_cache, cLambda.clientParameters,
seed_msb, seed_lsb));
return TestTypedLambda(cLambda, sLambda, keySet);
}
TestTypedLambda(ClientLambda &cLambda, ServerLambda &sLambda,
std::shared_ptr<KeySet> keySet)
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
keySet(keySet) {}
TestTypedLambda(TypedClientLambda<Result, Args...> &cLambda,
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet)
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
keySet(keySet) {}
outcome::checked<Result, StringError> call(Args... args) {
// std::string message;
// client stream
// std::ostringstream clientOuput(std::ios::binary);
// client argument encryption
OUTCOME_TRY(auto encryptedArgs,
clientlib::EncryptedArguments::create(/*simulation*/ false,
*keySet, args...));
OUTCOME_TRY(auto publicArgument,
encryptedArgs->exportPublicArguments(this->clientParameters));
// client argument serialization
// publicArgument->serialize(clientOuput);
// message = clientOuput.str();
// server stream
// std::istringstream serverInput(message, std::ios::binary);
// freeStringMemory(message);
//
// OUTCOME_TRY(auto publicArguments,
// clientlib::PublicArguments::unserialize(
// this->clientParameters,
// serverInput));
// server function call
auto evaluationKeys = keySet->evaluationKeys();
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
if (!publicResult) {
return StringError("failed calling function");
}
// client result decryption
return this->decryptResult(*keySet, *(publicResult.get()));
}
private:
ServerLambda serverLambda;
std::shared_ptr<KeySet> keySet;
};
template <typename Result, typename... Args>
static TestTypedLambda<Result, Args...> TestTypedLambdaFrom(
concretelang::clientlib::TypedClientLambda<Result, Args...> &cLambda,
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet) {
return TestTypedLambda<Result, Args...>(cLambda, sLambda, keySet);
}
} // namespace testlib
} // namespace concretelang
#endif

View File

@@ -1,3 +1,5 @@
add_compile_options(-fsized-deallocation)
add_mlir_library(
AnalysisUtils
Utils.cpp

View File

@@ -8,7 +8,14 @@ add_compile_options(-fexceptions)
# ######################################################################################################################
set(LLVM_OPTIONAL_SOURCES CompilerAPIModule.cpp ConcretelangModule.cpp FHEModule.cpp)
add_mlir_public_c_api_library(CONCRETELANGPySupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR ConcretelangSupport)
add_mlir_public_c_api_library(
CONCRETELANGPySupport
CompilerEngine.cpp
LINK_LIBS
PUBLIC
MLIRCAPIIR
ConcretelangSupport
ConcretelangCommon)
# ######################################################################################################################
# Declare native Python extension
@@ -28,9 +35,6 @@ declare_mlir_python_extension(
CompilerAPIModule.cpp
EMBED_CAPI_LINK_LIBS
MLIRCAPIRegisterEverything
CONCRETELANGCAPIFHE
CONCRETELANGCAPIFHELINALG
CONCRETELANGCAPITRACING
CONCRETELANGPySupport)
# ######################################################################################################################
@@ -48,9 +52,6 @@ declare_mlir_python_sources(
concrete/compiler/compilation_context.py
concrete/compiler/compilation_feedback.py
concrete/compiler/compilation_options.py
concrete/compiler/jit_compilation_result.py
concrete/compiler/jit_support.py
concrete/compiler/jit_lambda.py
concrete/compiler/key_set_cache.py
concrete/compiler/key_set.py
concrete/compiler/lambda_argument.py

View File

@@ -4,10 +4,13 @@
// for license information.
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Bindings/Python/CompilerEngine.h"
#include "concretelang/ClientLib/ClientLib.h"
#include "concretelang/Common/Compat.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/Jit.h"
#include "concretelang/Support/logging.h"
#include <llvm/Support/Debug.h>
#include <mlir-c/Bindings/Python/Interop.h>
@@ -24,7 +27,6 @@
#include <string>
using mlir::concretelang::CompilationOptions;
using mlir::concretelang::JITSupport;
using mlir::concretelang::LambdaArgument;
class SignalGuard {
@@ -75,7 +77,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](std::string funcname) { return CompilationOptions(funcname); }))
.def("set_funcname",
[](CompilationOptions &options, std::string funcname) {
options.clientParametersFuncName = funcname;
options.mainFuncName = funcname;
})
.def("set_verify_diagnostics",
[](CompilationOptions &options, bool b) {
@@ -207,11 +209,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
"memory_usage_per_location",
&mlir::concretelang::CompilationFeedback::memoryUsagePerLoc);
pybind11::class_<mlir::concretelang::JitCompilationResult>(
m, "JITCompilationResult");
pybind11::class_<mlir::concretelang::JITLambda,
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
"JITLambda");
pybind11::class_<mlir::concretelang::CompilationContext,
std::shared_ptr<mlir::concretelang::CompilationContext>>(
m, "CompilationContext")
@@ -224,51 +221,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return pybind11::reinterpret_steal<pybind11::object>(
mlirPythonContextToCapsule(wrap(mlirCtx)));
});
pybind11::class_<JITSupport_Py>(m, "JITSupport")
.def(pybind11::init([](std::string runtimeLibPath) {
return jit_support(runtimeLibPath);
}))
.def("compile",
[](JITSupport_Py &support, std::string mlir_program,
CompilationOptions options) {
SignalGuard signalGuard;
return jit_compile(support, mlir_program.c_str(), options);
})
.def("compile",
[](JITSupport_Py &support, pybind11::object mlir_module,
CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
SignalGuard signalGuard;
return jit_compile_module(
support,
unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(),
options, cctx);
})
.def("load_client_parameters",
[](JITSupport_Py &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_client_parameters(support, result);
})
.def("load_compilation_feedback",
[](JITSupport_Py &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_compilation_feedback(support, result);
})
.def(
"load_server_lambda",
[](JITSupport_Py &support,
mlir::concretelang::JitCompilationResult &result) {
return jit_load_server_lambda(support, result);
},
pybind11::return_value_policy::reference)
.def("server_call",
[](JITSupport_Py &support, concretelang::JITLambda &lambda,
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
return jit_server_call(support, lambda, publicArguments,
evaluationKeys);
});
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
m, "LibraryCompilationResult")
@@ -278,7 +230,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
funcname,
};
}));
pybind11::class_<concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
pybind11::class_<LibrarySupport_Py>(m, "LibrarySupport")
.def(pybind11::init(
[](std::string outputPath, std::string runtimeLibraryPath,
@@ -319,21 +271,24 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"load_server_lambda",
[](LibrarySupport_Py &support,
mlir::concretelang::LibraryCompilationResult &result) {
return library_load_server_lambda(support, result);
mlir::concretelang::LibraryCompilationResult &result,
bool useSimulation) {
return library_load_server_lambda(support, result, useSimulation);
},
pybind11::return_value_policy::reference)
.def("server_call",
[](LibrarySupport_Py &support, serverlib::ServerLambda lambda,
clientlib::PublicArguments &publicArguments,
clientlib::EvaluationKeys &evaluationKeys) {
[](LibrarySupport_Py &support,
::concretelang::serverlib::ServerLambda lambda,
::concretelang::clientlib::PublicArguments &publicArguments,
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
SignalGuard signalGuard;
return library_server_call(support, lambda, publicArguments,
evaluationKeys);
})
.def("simulate",
[](LibrarySupport_Py &support, serverlib::ServerLambda lambda,
clientlib::PublicArguments &publicArguments) {
[](LibrarySupport_Py &support,
::concretelang::serverlib::ServerLambda lambda,
::concretelang::clientlib::PublicArguments &publicArguments) {
pybind11::gil_scoped_release release;
return library_simulate(support, lambda, publicArguments);
})
@@ -341,8 +296,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](LibrarySupport_Py &support) {
return library_get_shared_lib_path(support);
})
.def("get_client_parameters_path", [](LibrarySupport_Py &support) {
return library_get_client_parameters_path(support);
.def("get_program_info_path", [](LibrarySupport_Py &support) {
return library_get_program_info_path(support);
});
class ClientSupport {};
@@ -350,120 +305,165 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(pybind11::init())
.def_static(
"key_set",
[](clientlib::ClientParameters clientParameters,
clientlib::KeySetCache *cache, uint64_t seedMsb,
[](::concretelang::clientlib::ClientParameters clientParameters,
::concretelang::clientlib::KeySetCache *cache, uint64_t seedMsb,
uint64_t seedLsb) {
SignalGuard signalGuard;
auto optCache = cache == nullptr
? std::nullopt
: std::optional<clientlib::KeySetCache>(*cache);
auto optCache =
cache == nullptr
? std::nullopt
: std::optional<::concretelang::clientlib::KeySetCache>(
*cache);
return key_set(clientParameters, optCache, seedMsb, seedLsb);
},
pybind11::arg().none(false), pybind11::arg().none(true),
pybind11::arg("seedMsb") = 0, pybind11::arg("seedLsb") = 0)
.def_static("encrypt_arguments",
[](clientlib::ClientParameters clientParameters,
clientlib::KeySet &keySet,
std::vector<lambdaArgument> args) {
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
for (auto i = 0u; i < args.size(); i++) {
argsRef.push_back(args[i].ptr.get());
}
return encrypt_arguments(clientParameters, keySet, argsRef);
})
.def_static("decrypt_result", [](clientlib::KeySet &keySet,
clientlib::PublicResult &publicResult) {
return decrypt_result(keySet, publicResult);
});
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
.def_static(
"encrypt_arguments",
[](::concretelang::clientlib::ClientParameters clientParameters,
::concretelang::clientlib::KeySet &keySet,
std::vector<lambdaArgument> args) {
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
for (auto i = 0u; i < args.size(); i++) {
argsRef.push_back(args[i].ptr.get());
}
return encrypt_arguments(clientParameters, keySet, argsRef);
})
.def_static(
"decrypt_result",
[](::concretelang::clientlib::ClientParameters clientParameters,
::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::PublicResult &publicResult) {
return decrypt_result(clientParameters, keySet, publicResult);
});
pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache")
.def(pybind11::init<std::string &>());
pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>(
m, "LweSecretKeyParam")
.def_readonly("dimension",
&::concretelang::clientlib::LweSecretKeyParam::dimension);
.def("dimension", [](::concretelang::clientlib::LweSecretKeyParam &key) {
return key.info.asReader().getParams().getLweDimension();
});
pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>(
m, "BootstrapKeyParam")
.def_readonly(
"input_secret_key_id",
&::concretelang::clientlib::BootstrapKeyParam::inputSecretKeyID)
.def_readonly(
"output_secret_key_id",
&::concretelang::clientlib::BootstrapKeyParam::outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::BootstrapKeyParam::level)
.def_readonly("base_log",
&::concretelang::clientlib::BootstrapKeyParam::baseLog)
.def_readonly(
"glwe_dimension",
&::concretelang::clientlib::BootstrapKeyParam::glweDimension)
.def_readonly("variance",
&::concretelang::clientlib::BootstrapKeyParam::variance)
.def_readonly(
"polynomial_size",
&::concretelang::clientlib::BootstrapKeyParam::polynomialSize)
.def_readonly(
"input_lwe_dimension",
&::concretelang::clientlib::BootstrapKeyParam::inputLweDimension);
.def("input_secret_key_id",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getInputId();
})
.def("output_secret_key_id",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getOutputId();
})
.def("level",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getLevelCount();
})
.def("base_log",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getBaseLog();
})
.def("glwe_dimension",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getGlweDimension();
})
.def("variance",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getVariance();
})
.def("polynomial_size",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getPolynomialSize();
})
.def("input_lwe_dimension",
[](::concretelang::clientlib::BootstrapKeyParam &key) {
return key.info.asReader().getParams().getInputLweDimension();
});
pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>(
m, "KeyswitchKeyParam")
.def_readonly(
"input_secret_key_id",
&::concretelang::clientlib::KeyswitchKeyParam::inputSecretKeyID)
.def_readonly(
"output_secret_key_id",
&::concretelang::clientlib::KeyswitchKeyParam::outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::KeyswitchKeyParam::level)
.def_readonly("base_log",
&::concretelang::clientlib::KeyswitchKeyParam::baseLog)
.def_readonly("variance",
&::concretelang::clientlib::KeyswitchKeyParam::variance);
.def("input_secret_key_id",
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
return key.info.asReader().getInputId();
})
.def("output_secret_key_id",
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
return key.info.asReader().getOutputId();
})
.def("level",
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
return key.info.asReader().getParams().getLevelCount();
})
.def("base_log",
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
return key.info.asReader().getParams().getBaseLog();
})
.def("variance", [](::concretelang::clientlib::KeyswitchKeyParam &key) {
return key.info.asReader().getParams().getVariance();
});
pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>(
m, "PackingKeyswitchKeyParam")
.def_readonly("input_secret_key_id",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
inputSecretKeyID)
.def_readonly("output_secret_key_id",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
outputSecretKeyID)
.def_readonly("level",
&::concretelang::clientlib::PackingKeyswitchKeyParam::level)
.def_readonly(
"base_log",
&::concretelang::clientlib::PackingKeyswitchKeyParam::baseLog)
.def_readonly(
"glwe_dimension",
&::concretelang::clientlib::PackingKeyswitchKeyParam::glweDimension)
.def_readonly(
"polynomial_size",
&::concretelang::clientlib::PackingKeyswitchKeyParam::polynomialSize)
.def_readonly("input_lwe_dimension",
&::concretelang::clientlib::PackingKeyswitchKeyParam::
inputLweDimension)
.def_readonly(
"variance",
&::concretelang::clientlib::PackingKeyswitchKeyParam::variance);
.def("input_secret_key_id",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getInputId();
})
.def("output_secret_key_id",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getOutputId();
})
.def("level",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getLevelCount();
})
.def("base_log",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getBaseLog();
})
.def("glwe_dimension",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getGlweDimension();
})
.def("polynomial_size",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getPolynomialSize();
})
.def("input_lwe_dimension",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getInputLweDimension();
})
.def("variance",
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
return key.info.asReader().getParams().getVariance();
});
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters")
pybind11::class_<::concretelang::clientlib::ClientParameters>(
m, "ClientParameters")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
return clientParametersUnserialize(buffer);
})
.def("serialize",
[](mlir::concretelang::ClientParameters &clientParameters) {
[](::concretelang::clientlib::ClientParameters &clientParameters) {
return pybind11::bytes(
clientParametersSerialize(clientParameters));
})
.def("output_signs",
[](mlir::concretelang::ClientParameters &clientParameters) {
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<bool> result;
for (auto output : clientParameters.outputs) {
if (output.encryption.has_value()) {
result.push_back(output.encryption.value().encoding.isSigned);
for (auto output : clientParameters.programInfo.asReader()
.getCircuits()[0]
.getOutputs()) {
if (output.getTypeInfo().hasLweCiphertext() &&
output.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.hasInteger()) {
result.push_back(output.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.getInteger()
.getIsSigned());
} else {
result.push_back(true);
}
@@ -471,11 +471,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return result;
})
.def("input_signs",
[](mlir::concretelang::ClientParameters &clientParameters) {
[](::concretelang::clientlib::ClientParameters &clientParameters) {
std::vector<bool> result;
for (auto input : clientParameters.inputs) {
if (input.encryption.has_value()) {
result.push_back(input.encryption.value().encoding.isSigned);
for (auto input : clientParameters.programInfo.asReader()
.getCircuits()[0]
.getInputs()) {
if (input.getTypeInfo().hasLweCiphertext() &&
input.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.hasInteger()) {
result.push_back(input.getTypeInfo()
.getLweCiphertext()
.getEncoding()
.getInteger()
.getIsSigned());
} else {
result.push_back(true);
}
@@ -483,246 +493,244 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
return result;
})
.def_readonly("secret_keys",
&mlir::concretelang::ClientParameters::secretKeys)
&::concretelang::clientlib::ClientParameters::secretKeys)
.def_readonly("bootstrap_keys",
&mlir::concretelang::ClientParameters::bootstrapKeys)
&::concretelang::clientlib::ClientParameters::bootstrapKeys)
.def_readonly("keyswitch_keys",
&mlir::concretelang::ClientParameters::keyswitchKeys)
&::concretelang::clientlib::ClientParameters::keyswitchKeys)
.def_readonly(
"packing_keyswitch_keys",
&mlir::concretelang::ClientParameters::packingKeyswitchKeys);
&::concretelang::clientlib::ClientParameters::packingKeyswitchKeys);
pybind11::class_<clientlib::KeySet>(m, "KeySet")
pybind11::class_<::concretelang::clientlib::KeySet>(m, "KeySet")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
std::unique_ptr<KeySet> result = keySetUnserialize(buffer);
std::unique_ptr<::concretelang::clientlib::KeySet> result =
keySetUnserialize(buffer);
return result;
})
.def("serialize",
[](clientlib::KeySet &keySet) {
[](::concretelang::clientlib::KeySet &keySet) {
return pybind11::bytes(keySetSerialize(keySet));
})
.def("client_parameters",
[](clientlib::KeySet &keySet) { return keySet.clientParameters(); })
.def("get_evaluation_keys",
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
[](::concretelang::clientlib::KeySet &keySet) {
return ::concretelang::clientlib::EvaluationKeys{
keySet.keyset.server};
});
pybind11::class_<clientlib::SharedScalarOrTensorData>(m, "Value")
pybind11::class_<::concretelang::clientlib::SharedScalarOrTensorData>(m,
"Value")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
return valueUnserialize(buffer);
})
.def("serialize", [](const clientlib::SharedScalarOrTensorData &value) {
return pybind11::bytes(valueSerialize(value));
});
.def(
"serialize",
[](const ::concretelang::clientlib::SharedScalarOrTensorData &value) {
return pybind11::bytes(valueSerialize(value));
});
pybind11::class_<clientlib::ValueExporter>(m, "ValueExporter")
.def_static("create",
[](clientlib::KeySet &keySet,
mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::ValueExporter(keySet, clientParameters);
})
pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter")
.def_static(
"create",
[](::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::ClientParameters &clientParameters) {
return createValueExporter(keySet, clientParameters);
})
.def("export_scalar",
[](clientlib::ValueExporter &exporter, size_t position,
int64_t value) {
[](::concretelang::clientlib::ValueExporter &exporter,
size_t position, int64_t value) {
SignalGuard signalGuard;
outcome::checked<clientlib::ScalarOrTensorData, StringError>
result = exporter.exportValue(value, position);
auto info = exporter.circuit.getCircuitInfo()
.asReader()
.getInputs()[position];
auto typeTransformer = getPythonTypeTransformer(info);
auto result = exporter.circuit.prepareInput(
typeTransformer({Tensor<int64_t>(value)}), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(
std::move(result.value()));
return ::concretelang::clientlib::SharedScalarOrTensorData{
result.value()};
})
.def("export_tensor", [](clientlib::ValueExporter &exporter,
.def("export_tensor", [](::concretelang::clientlib::ValueExporter
&exporter,
size_t position, std::vector<int64_t> values,
std::vector<int64_t> shape) {
SignalGuard signalGuard;
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
exporter.exportValue(values.data(), shape, position);
std::vector<size_t> dimensions(shape.begin(), shape.end());
auto info =
exporter.circuit.getCircuitInfo().asReader().getInputs()[position];
auto typeTransformer = getPythonTypeTransformer(info);
auto result = exporter.circuit.prepareInput(
typeTransformer({Tensor<int64_t>(values, dimensions)}), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
return ::concretelang::clientlib::SharedScalarOrTensorData{
result.value()};
});
pybind11::class_<clientlib::SimulatedValueExporter>(m,
"SimulatedValueExporter")
.def_static("create",
[](mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::SimulatedValueExporter(clientParameters);
})
pybind11::class_<::concretelang::clientlib::SimulatedValueExporter>(
m, "SimulatedValueExporter")
.def_static(
"create",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
return createSimulatedValueExporter(clientParameters);
})
.def("export_scalar",
[](clientlib::SimulatedValueExporter &exporter, size_t position,
int64_t value) {
outcome::checked<clientlib::ScalarOrTensorData, StringError>
result = exporter.exportValue(value, position);
[](::concretelang::clientlib::SimulatedValueExporter &exporter,
size_t position, int64_t value) {
SignalGuard signalGuard;
auto info = exporter.circuit.getCircuitInfo()
.asReader()
.getInputs()[position];
auto typeTransformer = getPythonTypeTransformer(info);
auto result = exporter.circuit.prepareInput(
typeTransformer({Tensor<int64_t>(value)}), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(
std::move(result.value()));
return ::concretelang::clientlib::SharedScalarOrTensorData{
result.value()};
})
.def("export_tensor", [](clientlib::SimulatedValueExporter &exporter,
.def("export_tensor", [](::concretelang::clientlib::SimulatedValueExporter
&exporter,
size_t position, std::vector<int64_t> values,
std::vector<int64_t> shape) {
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
exporter.exportValue(values.data(), shape, position);
SignalGuard signalGuard;
std::vector<size_t> dimensions(shape.begin(), shape.end());
auto info =
exporter.circuit.getCircuitInfo().asReader().getInputs()[position];
auto typeTransformer = getPythonTypeTransformer(info);
auto result = exporter.circuit.prepareInput(
typeTransformer({Tensor<int64_t>(values, dimensions)}), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
return ::concretelang::clientlib::SharedScalarOrTensorData{
result.value()};
});
pybind11::class_<clientlib::ValueDecrypter>(m, "ValueDecrypter")
.def_static("create",
[](clientlib::KeySet &keySet,
mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::ValueDecrypter(keySet, clientParameters);
})
.def("get_shape",
[](clientlib::ValueDecrypter &decrypter, size_t position) {
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.getShape(position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_scalar",
[](clientlib::ValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
pybind11::class_<::concretelang::clientlib::ValueDecrypter>(m,
"ValueDecrypter")
.def_static(
"create",
[](::concretelang::clientlib::KeySet &keySet,
::concretelang::clientlib::ClientParameters &clientParameters) {
return createValueDecrypter(keySet, clientParameters);
})
.def("decrypt",
[](::concretelang::clientlib::ValueDecrypter &decrypter,
size_t position,
::concretelang::clientlib::SharedScalarOrTensorData &value) {
SignalGuard signalGuard;
outcome::checked<int64_t, StringError> result =
decrypter.decrypt<int64_t>(value.get(), position);
auto result =
decrypter.circuit.processOutput(value.value, position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_tensor",
[](clientlib::ValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
SignalGuard signalGuard;
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.decryptTensor<int64_t>(value.get(), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
return lambdaArgument{
std::make_shared<mlir::concretelang::LambdaArgument>(
mlir::concretelang::LambdaArgument{result.value()})};
});
pybind11::class_<clientlib::SimulatedValueDecrypter>(
pybind11::class_<::concretelang::clientlib::SimulatedValueDecrypter>(
m, "SimulatedValueDecrypter")
.def_static("create",
[](mlir::concretelang::ClientParameters &clientParameters) {
return clientlib::SimulatedValueDecrypter(clientParameters);
})
.def("get_shape",
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position) {
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.getShape(position);
.def_static(
"create",
[](::concretelang::clientlib::ClientParameters &clientParameters) {
return createSimulatedValueDecrypter(clientParameters);
})
.def("decrypt",
[](::concretelang::clientlib::SimulatedValueDecrypter &decrypter,
size_t position,
::concretelang::clientlib::SharedScalarOrTensorData &value) {
SignalGuard signalGuard;
auto result =
decrypter.circuit.processOutput(value.value, position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_scalar",
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
outcome::checked<int64_t, StringError> result =
decrypter.decrypt<int64_t>(value.get(), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
})
.def("decrypt_tensor",
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position,
clientlib::SharedScalarOrTensorData &value) {
outcome::checked<std::vector<int64_t>, StringError> result =
decrypter.decryptTensor<int64_t>(value.get(), position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
}
return result.value();
return lambdaArgument{
std::make_shared<mlir::concretelang::LambdaArgument>(
mlir::concretelang::LambdaArgument{result.value()})};
});
pybind11::class_<clientlib::PublicArguments,
std::unique_ptr<clientlib::PublicArguments>>(
pybind11::class_<::concretelang::clientlib::PublicArguments,
std::unique_ptr<::concretelang::clientlib::PublicArguments>>(
m, "PublicArguments")
.def_static(
"create",
[](const mlir::concretelang::ClientParameters &clientParameters,
std::vector<clientlib::SharedScalarOrTensorData> &buffers) {
return clientlib::PublicArguments(clientParameters, buffers);
[](const ::concretelang::clientlib::ClientParameters
&clientParameters,
std::vector<::concretelang::clientlib::SharedScalarOrTensorData>
&buffers) {
std::vector<TransportValue> vals;
for (auto buf : buffers) {
vals.push_back(buf.value);
}
return ::concretelang::clientlib::PublicArguments{vals};
})
.def_static(
"deserialize",
[](::concretelang::clientlib::ClientParameters &clientParameters,
const pybind11::bytes &buffer) {
return publicArgumentsUnserialize(clientParameters, buffer);
})
.def_static("deserialize",
[](mlir::concretelang::ClientParameters &clientParameters,
const pybind11::bytes &buffer) {
return publicArgumentsUnserialize(clientParameters, buffer);
})
.def("serialize", [](clientlib::PublicArguments &publicArgument) {
return pybind11::bytes(publicArgumentsSerialize(publicArgument));
});
pybind11::class_<clientlib::PublicResult>(m, "PublicResult")
.def_static("deserialize",
[](mlir::concretelang::ClientParameters &clientParameters,
const pybind11::bytes &buffer) {
return publicResultUnserialize(clientParameters, buffer);
})
.def("serialize",
[](clientlib::PublicResult &publicResult) {
[](::concretelang::clientlib::PublicArguments &publicArgument) {
return pybind11::bytes(publicArgumentsSerialize(publicArgument));
});
pybind11::class_<::concretelang::clientlib::PublicResult>(m, "PublicResult")
.def_static(
"deserialize",
[](::concretelang::clientlib::ClientParameters &clientParameters,
const pybind11::bytes &buffer) {
return publicResultUnserialize(clientParameters, buffer);
})
.def("serialize",
[](::concretelang::clientlib::PublicResult &publicResult) {
return pybind11::bytes(publicResultSerialize(publicResult));
})
.def("n_values",
[](const clientlib::PublicResult &publicResult) {
return publicResult.buffers.size();
[](const ::concretelang::clientlib::PublicResult &publicResult) {
return publicResult.values.size();
})
.def("get_value",
[](clientlib::PublicResult &publicResult, size_t position) {
outcome::checked<clientlib::SharedScalarOrTensorData, StringError>
result = publicResult.getValue(position);
if (result.has_error()) {
throw std::runtime_error(result.error().mesg);
[](::concretelang::clientlib::PublicResult &publicResult,
size_t position) {
if (position >= publicResult.values.size()) {
throw std::runtime_error("Failed to get public result value.");
}
return result.value();
return ::concretelang::clientlib::SharedScalarOrTensorData{
publicResult.values[position]};
});
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
pybind11::class_<::concretelang::clientlib::EvaluationKeys>(m,
"EvaluationKeys")
.def_static("deserialize",
[](const pybind11::bytes &buffer) {
return evaluationKeysUnserialize(buffer);
})
.def("serialize", [](clientlib::EvaluationKeys &evaluationKeys) {
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
});
.def("serialize",
[](::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor_u8",

View File

@@ -4,79 +4,19 @@
// for license information.
#include "llvm/ADT/SmallString.h"
#include <cstdint>
#include <memory>
#include <stdexcept>
#include "capnp/message.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/Bindings/Python/CompilerEngine.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Compat.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Values.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/Jit.h"
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
auto VARNAME = EXPECTED; \
if (auto err = VARNAME.takeError()) { \
throw std::runtime_error(llvm::toString(std::move(err))); \
}
// JIT Support bindings ///////////////////////////////////////////////////////
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) {
auto opt = runtimeLibPath.empty()
? std::nullopt
: std::optional<std::string>(runtimeLibPath);
return JITSupport_Py{mlir::concretelang::JITSupport(opt)};
}
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile(JITSupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, options));
return std::move(*compilationResult);
}
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
jit_compile_module(
JITSupport_Py support, mlir::ModuleOp module,
mlir::concretelang::CompilationOptions options,
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, cctx, options));
return std::move(*compilationResult);
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
jit_load_client_parameters(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
support.support.loadClientParameters(result));
return *clientParameters;
}
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
jit_load_compilation_feedback(
JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
support.support.loadCompilationFeedback(result));
return *compilationFeedback;
}
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
jit_load_server_lambda(JITSupport_Py support,
mlir::concretelang::JitCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
support.support.loadServerLambda(result));
return *serverLambda;
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
concretelang::clientlib::PublicArguments &args,
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys));
return std::move(*publicResult);
}
// Library Support bindings ///////////////////////////////////////////////////
MLIR_CAPI_EXPORTED LibrarySupport_Py
@@ -86,15 +26,17 @@ library_support(const char *outputPath, const char *runtimeLibraryPath,
bool generateCppHeader) {
return LibrarySupport_Py{mlir::concretelang::LibrarySupport(
outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib,
generateClientParameters, generateCompilationFeedback,
generateCppHeader)};
generateClientParameters, generateCompilationFeedback)};
}
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
library_compile(LibrarySupport_Py support, const char *module,
mlir::concretelang::CompilationOptions options) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module),
llvm::SMLoc());
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
support.support.compile(module, options));
support.support.compile(sm, options));
return std::move(*compilationResult);
}
@@ -108,7 +50,7 @@ library_compile_module(
return std::move(*compilationResult);
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
library_load_client_parameters(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
@@ -127,11 +69,11 @@ library_load_compilation_feedback(
}
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
library_load_server_lambda(
LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result) {
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
support.support.loadServerLambda(result));
library_load_server_lambda(LibrarySupport_Py support,
mlir::concretelang::LibraryCompilationResult &result,
bool useSimulation) {
GET_OR_THROW_LLVM_EXPECTED(
serverLambda, support.support.loadServerLambda(result, useSimulation));
return *serverLambda;
}
@@ -160,8 +102,8 @@ library_get_shared_lib_path(LibrarySupport_Py support) {
}
MLIR_CAPI_EXPORTED std::string
library_get_client_parameters_path(LibrarySupport_Py support) {
return support.support.getClientParametersPath();
library_get_program_info_path(LibrarySupport_Py support) {
return support.support.getProgramInfoPath();
}
// Client Support bindings ///////////////////////////////////////////////////
@@ -170,157 +112,303 @@ MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
key_set(concretelang::clientlib::ClientParameters clientParameters,
std::optional<concretelang::clientlib::KeySetCache> cache,
uint64_t seedMsb, uint64_t seedLsb) {
GET_OR_THROW_LLVM_EXPECTED(
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(
clientParameters, cache, seedMsb, seedLsb)));
return std::move(*ks);
if (cache.has_value()) {
GET_OR_THROW_RESULT(Keyset keyset,
(*cache).keysetCache.getKeyset(
clientParameters.programInfo.asReader().getKeyset(),
seedMsb, seedLsb));
concretelang::clientlib::KeySet output{keyset};
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
} else {
__uint128_t seed = seedMsb;
seed <<= 64;
seed += seedLsb;
auto csprng = concretelang::csprng::ConcreteCSPRNG(seed);
auto keyset =
Keyset(clientParameters.programInfo.asReader().getKeyset(), csprng);
concretelang::clientlib::KeySet output{keyset};
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
}
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
GET_OR_THROW_LLVM_EXPECTED(
publicArguments,
(mlir::concretelang::LambdaSupport<int, int>::exportArguments(
clientParameters, keySet, args)));
return std::move(*publicArguments);
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
false);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto circuit = maybeProgram.value()
.getClientCircuit(clientParameters.programInfo.asReader()
.getCircuits()[0]
.getName())
.value();
std::vector<TransportValue> output;
for (size_t i = 0; i < args.size(); i++) {
auto info =
clientParameters.programInfo.asReader().getCircuits()[0].getInputs()[i];
auto typeTransformer = getPythonTypeTransformer(info);
auto input = typeTransformer(args[i]->value);
auto maybePrepared = circuit.prepareInput(input, i);
if (maybePrepared.has_failure()) {
throw std::runtime_error(maybePrepared.as_failure().error().mesg);
}
output.push_back(maybePrepared.value());
}
concretelang::clientlib::PublicArguments publicArgs{output};
return std::make_unique<concretelang::clientlib::PublicArguments>(
std::move(publicArgs));
}
MLIR_CAPI_EXPORTED lambdaArgument
decrypt_result(concretelang::clientlib::KeySet &keySet,
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::PublicResult &publicResult) {
GET_OR_THROW_LLVM_EXPECTED(
result, mlir::concretelang::typedResult<
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
keySet, publicResult));
lambdaArgument result_{std::move(*result)};
return result_;
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
false);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
if (publicResult.values.size() != 1) {
throw std::runtime_error("Tried to decrypt with wrong arity.");
}
auto circuit = maybeProgram.value()
.getClientCircuit(clientParameters.programInfo.asReader()
.getCircuits()[0]
.getName())
.value();
auto maybeProcessed = circuit.processOutput(publicResult.values[0], 0);
if (maybeProcessed.has_failure()) {
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
}
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
publicArgumentsUnserialize(
mlir::concretelang::ClientParameters &clientParameters,
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &buffer) {
std::stringstream istream(buffer);
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
clientParameters, istream);
if (!argsOrError) {
throw std::runtime_error(argsOrError.error().mesg);
auto publicArgumentsProto = Message<concreteprotocol::PublicArguments>();
if (publicArgumentsProto.readBinaryFromString(buffer).has_failure()) {
throw std::runtime_error("Failed to deserialize public arguments.");
}
return std::move(argsOrError.value());
std::vector<TransportValue> values;
for (auto arg : publicArgumentsProto.asReader().getArgs()) {
values.push_back(arg);
}
concretelang::clientlib::PublicArguments output{values};
return std::make_unique<concretelang::clientlib::PublicArguments>(
std::move(output));
}
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
concretelang::clientlib::PublicArguments &publicArguments) {
std::ostringstream buffer(std::ios::binary);
auto voidOrError = publicArguments.serialize(buffer);
if (!voidOrError) {
throw std::runtime_error(voidOrError.error().mesg);
auto publicArgumentsProto = Message<concreteprotocol::PublicArguments>();
auto argBuilder =
publicArgumentsProto.asBuilder().initArgs(publicArguments.values.size());
for (size_t i = 0; i < publicArguments.values.size(); i++) {
argBuilder.setWithCaveats(i, publicArguments.values[i].asReader());
}
return buffer.str();
auto maybeBuffer = publicArgumentsProto.writeBinaryToString();
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize public arguments.");
}
return maybeBuffer.value();
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
const std::string &buffer) {
std::stringstream istream(buffer);
auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize(
clientParameters, istream);
if (!publicResultOrError) {
throw std::runtime_error(publicResultOrError.error().mesg);
publicResultUnserialize(
concretelang::clientlib::ClientParameters &clientParameters,
const std::string &buffer) {
auto publicResultsProto = Message<concreteprotocol::PublicResults>();
if (publicResultsProto.readBinaryFromString(buffer).has_failure()) {
throw std::runtime_error("Failed to deserialize public results.");
}
return std::move(publicResultOrError.value());
std::vector<TransportValue> values;
for (auto res : publicResultsProto.asReader().getResults()) {
values.push_back(res);
}
concretelang::clientlib::PublicResult output{values};
return std::make_unique<concretelang::clientlib::PublicResult>(
std::move(output));
}
MLIR_CAPI_EXPORTED std::string
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
std::ostringstream buffer(std::ios::binary);
auto voidOrError = publicResult.serialize(buffer);
if (!voidOrError) {
throw std::runtime_error(voidOrError.error().mesg);
std::string buffer;
auto publicResultsProto = Message<concreteprotocol::PublicResults>();
auto resBuilder =
publicResultsProto.asBuilder().initResults(publicResult.values.size());
for (size_t i = 0; i < publicResult.values.size(); i++) {
resBuilder.setWithCaveats(i, publicResult.values[i].asReader());
}
return buffer.str();
auto maybeBuffer = publicResultsProto.writeBinaryToString();
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize public results.");
}
return maybeBuffer.value();
}
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
evaluationKeysUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
concretelang::clientlib::EvaluationKeys evaluationKeys =
concretelang::clientlib::readEvaluationKeys(istream);
if (istream.fail()) {
throw std::runtime_error("Cannot read evaluation keys");
auto serverKeysetProto = Message<concreteprotocol::ServerKeyset>();
auto maybeError = serverKeysetProto.readBinaryFromString(
buffer, capnp::ReaderOptions{7000000000, 64});
if (maybeError.has_failure()) {
throw std::runtime_error("Failed to deserialize server keyset." +
maybeError.as_failure().error().mesg);
}
return evaluationKeys;
auto serverKeyset =
concretelang::keysets::ServerKeyset::fromProto(serverKeysetProto);
concretelang::clientlib::EvaluationKeys output{serverKeyset};
return output;
}
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
std::ostringstream buffer(std::ios::binary);
concretelang::clientlib::operator<<(buffer, evaluationKeys);
return buffer.str();
auto serverKeysetProto = evaluationKeys.keyset.toProto();
auto maybeBuffer = serverKeysetProto.writeBinaryToString();
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize evaluation keys.");
}
return maybeBuffer.value();
}
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
keySetUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
std::unique_ptr<concretelang::clientlib::KeySet> keySet =
concretelang::clientlib::readKeySet(istream);
if (istream.fail() || keySet.get() == nullptr) {
throw std::runtime_error("Cannot read key set");
auto keysetProto = Message<concreteprotocol::Keyset>();
auto maybeError = keysetProto.readBinaryFromString(
buffer, capnp::ReaderOptions{7000000000, 64});
if (maybeError.has_failure()) {
throw std::runtime_error("Failed to deserialize keyset." +
maybeError.as_failure().error().mesg);
}
return keySet;
auto keyset = concretelang::keysets::Keyset::fromProto(keysetProto);
concretelang::clientlib::KeySet output{keyset};
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
}
MLIR_CAPI_EXPORTED std::string
keySetSerialize(concretelang::clientlib::KeySet &keySet) {
std::ostringstream buffer(std::ios::binary);
concretelang::clientlib::operator<<(buffer, keySet);
return buffer.str();
auto keysetProto = keySet.keyset.toProto();
auto maybeBuffer = keysetProto.writeBinaryToString();
if (maybeBuffer.has_failure()) {
throw std::runtime_error("Failed to serialize keys.");
}
return maybeBuffer.value();
}
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
valueUnserialize(const std::string &buffer) {
std::stringstream istream(buffer);
auto value = concretelang::clientlib::unserializeScalarOrTensorData(istream);
if (istream.fail() || value.has_error()) {
throw std::runtime_error("Cannot read data");
auto inner = TransportValue();
if (inner.readBinaryFromString(buffer).has_failure()) {
throw std::runtime_error("Failed to deserialize Value");
}
return concretelang::clientlib::SharedScalarOrTensorData(
std::move(value.value()));
return {inner};
}
MLIR_CAPI_EXPORTED std::string
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
std::ostringstream buffer(std::ios::binary);
serializeScalarOrTensorData(value.get(), buffer);
return buffer.str();
auto maybeString = value.value.writeBinaryToString();
if (maybeString.has_failure()) {
throw std::runtime_error("Failed to serialize Value");
}
return maybeString.value();
}
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueExporter createValueExporter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
false);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()};
}
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueExporter
createSimulatedValueExporter(
concretelang::clientlib::ClientParameters &clientParameters) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(),
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
true);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
return ::concretelang::clientlib::SimulatedValueExporter{
maybeCircuit.value()};
}
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueDecrypter createValueDecrypter(
concretelang::clientlib::KeySet &keySet,
concretelang::clientlib::ClientParameters &clientParameters) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(), keySet.keyset.client,
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
false);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()};
}
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueDecrypter
createSimulatedValueDecrypter(
concretelang::clientlib::ClientParameters &clientParameters) {
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
clientParameters.programInfo.asReader(),
::concretelang::keysets::ClientKeyset(),
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
true);
if (maybeProgram.has_failure()) {
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
}
auto maybeCircuit = maybeProgram.value().getClientCircuit(
clientParameters.programInfo.asReader().getCircuits()[0].getName());
return ::concretelang::clientlib::SimulatedValueDecrypter{
maybeCircuit.value()};
}
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
clientParametersUnserialize(const std::string &json) {
GET_OR_THROW_LLVM_EXPECTED(
clientParams,
llvm::json::parse<mlir::concretelang::ClientParameters>(json));
return clientParams.get();
auto programInfo = Message<concreteprotocol::ProgramInfo>();
if (programInfo.readJsonFromString(json).has_failure()) {
throw std::runtime_error("Failed to deserialize client parameters");
}
return concretelang::clientlib::ClientParameters{programInfo, {}, {}, {}, {}};
}
MLIR_CAPI_EXPORTED std::string
clientParametersSerialize(mlir::concretelang::ClientParameters &params) {
llvm::json::Value value(params);
std::string jsonParams;
llvm::raw_string_ostream buffer(jsonParams);
buffer << value;
return jsonParams;
clientParametersSerialize(concretelang::clientlib::ClientParameters &params) {
auto maybeJson = params.programInfo.writeJsonToString();
if (maybeJson.has_failure()) {
throw std::runtime_error("Failed to serialize client parameters");
}
return maybeJson.value();
}
MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); }
@@ -350,283 +438,180 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) {
}
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int8_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int16_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int32_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int64_t>>>();
}
template <typename T, typename R>
MLIR_CAPI_EXPORTED std::vector<R> copyTensorLambdaArgumentTo64bitsvector(
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> *tensor) {
auto numElements = tensor->getNumElements();
if (!numElements) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(numElements.takeError()));
throw std::runtime_error(os.str());
}
std::vector<R> res;
res.reserve(*numElements);
T *data = tensor->getValue();
for (size_t i = 0; i < *numElements; i++) {
res.push_back(data[i]);
}
return res;
return !lambda_arg.ptr->value.isScalar();
}
MLIR_CAPI_EXPORTED std::vector<uint64_t>
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(sizeOrErr.takeError());
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
if (auto tensor = lambda_arg.ptr->value.getTensor<uint8_t>(); tensor) {
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint16_t>();
tensor) {
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint32_t>();
tensor) {
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint64_t>();
tensor) {
return tensor.value().values;
} else {
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<uint8_t, uint64_t>(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<uint16_t, uint64_t>(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<uint32_t, uint64_t>(arg);
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
MLIR_CAPI_EXPORTED std::vector<int64_t>
lambdaArgumentGetSignedTensorData(lambdaArgument &lambda_arg) {
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int64_t>>>()) {
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(sizeOrErr.takeError());
throw std::runtime_error(os.str());
}
std::vector<int64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
if (auto tensor = lambda_arg.ptr->value.getTensor<int8_t>(); tensor) {
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int16_t>(); tensor) {
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int32_t>(); tensor) {
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
return out.values;
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int64_t>(); tensor) {
return tensor.value().values;
} else {
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int8_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<int8_t, int64_t>(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int16_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<int16_t, int64_t>(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int32_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector<int32_t, int64_t>(arg);
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
MLIR_CAPI_EXPORTED std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int8_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int16_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int32_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int64_t>>>()) {
return arg->getDimensions();
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<(u)int{8,16,32,64}_t>>");
std::vector<size_t> dims = lambda_arg.ptr->value.getDimensions();
return {dims.begin(), dims.end()};
}
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
auto ptr = lambda_arg.ptr;
return ptr->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>() ||
ptr->isa<mlir::concretelang::IntLambdaArgument<int64_t>>();
return lambda_arg.ptr->value.isScalar();
}
MLIR_CAPI_EXPORTED bool lambdaArgumentIsSigned(lambdaArgument &lambda_arg) {
auto ptr = lambda_arg.ptr;
return ptr->isa<mlir::concretelang::IntLambdaArgument<int8_t>>() ||
ptr->isa<mlir::concretelang::IntLambdaArgument<int16_t>>() ||
ptr->isa<mlir::concretelang::IntLambdaArgument<int32_t>>() ||
ptr->isa<mlir::concretelang::IntLambdaArgument<int64_t>>() ||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int8_t>>>() ||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int16_t>>>() ||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int32_t>>>() ||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int64_t>>>();
;
return lambda_arg.ptr->value.isSigned();
}
MLIR_CAPI_EXPORTED uint64_t
lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
lambda_arg.ptr
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
if (arg == nullptr) {
if (lambda_arg.ptr->value.isScalar() &&
lambda_arg.ptr->value.hasElementType<uint64_t>()) {
return lambda_arg.ptr->value.getTensor<uint64_t>()->values[0];
} else {
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
"be an IntLambdaArgument<uint64_t>");
}
return arg->getValue();
}
MLIR_CAPI_EXPORTED int64_t
lambdaArgumentGetSignedScalar(lambdaArgument &lambda_arg) {
mlir::concretelang::IntLambdaArgument<int64_t> *arg =
lambda_arg.ptr
->dyn_cast<mlir::concretelang::IntLambdaArgument<int64_t>>();
if (arg == nullptr) {
if (lambda_arg.ptr->value.isScalar() &&
lambda_arg.ptr->value.hasElementType<int64_t>()) {
return lambda_arg.ptr->value.getTensor<int64_t>()->values[0];
} else {
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
"be an IntLambdaArgument<int64_t>");
}
return arg->getValue();
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
std::vector<uint8_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<uint8_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI8(
std::vector<int8_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<int8_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int8_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
std::vector<uint16_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<uint16_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI16(
std::vector<int16_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<int16_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int16_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
std::vector<uint32_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<uint32_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI32(
std::vector<int32_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<int32_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int32_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
std::vector<uint64_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<uint64_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI64(
std::vector<int64_t> data, std::vector<int64_t> dimensions) {
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
auto val = Value{((Tensor<int64_t>)Tensor<int64_t>(data, dims))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument tensor_arg{
std::make_shared<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<int64_t>>>(data, dimensions)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return tensor_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
auto val = Value{((Tensor<int64_t>)Tensor<uint64_t>(scalar))};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument scalar_arg{
std::make_shared<mlir::concretelang::IntLambdaArgument<uint64_t>>(
scalar)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return scalar_arg;
}
MLIR_CAPI_EXPORTED lambdaArgument
lambdaArgumentFromSignedScalar(int64_t scalar) {
auto val = Value{Tensor<int64_t>(scalar)};
mlir::concretelang::LambdaArgument out{val};
lambdaArgument scalar_arg{
std::make_shared<mlir::concretelang::IntLambdaArgument<int64_t>>(scalar)};
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
return scalar_arg;
}
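
The `lambdaArgumentFrom*` / `lambdaArgumentGet*` bindings above collapse the old per-bitwidth `TensorLambdaArgument<IntLambdaArgument<T>>` variants into a single `LambdaArgument` backed by a `Value`/`Tensor`. A minimal sketch of the behaviour this exposes on the Python side, assuming `lambda_arg` is a native `LambdaArgument` obtained from a decrypt or export call (only accessor names that appear in this diff are used):

```python
# Hedged sketch, not part of the diff: reading back a unified LambdaArgument.
# `lambda_arg` is assumed to come from a decrypt/export call as elsewhere in
# this commit; every accessor below appears in this diff.
if lambda_arg.is_scalar():
    value = (
        lambda_arg.get_signed_scalar()
        if lambda_arg.is_signed()
        else lambda_arg.get_scalar()
    )
else:
    shape = lambda_arg.get_tensor_shape()
    data = (
        lambda_arg.get_signed_tensor_data()
        if lambda_arg.is_signed()
        else lambda_arg.get_tensor_data()
    )
```

Signedness and shape are now properties of the argument itself rather than of eight separate template instantiations, which is what lets the long dispatch chains above shrink to a handful of `getTensor<T>()` calls.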

View File

@@ -3,16 +3,17 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Dialect/FHE.h"
#include "concretelang-c/Dialect/FHELinalg.h"
#include "concretelang-c/Dialect/Tracing.h"
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
#include "concretelang/Bindings/Python/DialectModules.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#include "concretelang/Support/Constants.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/IR.h"
#include "mlir-c/RegisterEverything.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/IR/DialectRegistry.h"
#include "llvm-c/ErrorHandling.h"
@@ -21,6 +22,19 @@
#include <pybind11/pybind11.h>
namespace py = pybind11;
using namespace mlir::concretelang::FHE;
using namespace mlir::concretelang::FHELinalg;
using namespace mlir::concretelang::Tracing;
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg, FHELinalgDialect)
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TRACING, tracing);
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Tracing, tracing, TracingDialect)
PYBIND11_MODULE(_concretelang, m) {
m.doc() = "Concretelang Python Native Extension";
llvm::sys::PrintStackTraceOnErrorSignal(/*argv=*/"");

View File

@@ -3,11 +3,12 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Dialect/FHE.h"
#include "concretelang/Bindings/Python/DialectModules.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir/Bindings/Python/PybindAdaptors.h"
#include "mlir/CAPI/IR.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/raw_ostream.h"
@@ -17,18 +18,47 @@
#include <pybind11/stl.h>
using namespace mlir::concretelang;
using namespace mlir::concretelang::FHE;
using namespace mlir::python::adaptors;
typedef struct {
MlirType type;
bool isError;
} MlirTypeOrError;
template <typename T>
MlirTypeOrError IntegerTypeGetChecked(MlirContext ctx, unsigned width) {
MlirTypeOrError type = {{NULL}, false};
auto catchError = [&]() -> mlir::InFlightDiagnostic {
type.isError = true;
mlir::DiagnosticEngine &engine = unwrap(ctx)->getDiagEngine();
    // The goal here is to make getChecked work, but we don't want the CAPI
    // to stop execution due to an error; the error handling logic is left to
    // the user of the CAPI
return engine.emit(mlir::UnknownLoc::get(unwrap(ctx)),
mlir::DiagnosticSeverity::Warning);
};
T integerType = T::getChecked(catchError, unwrap(ctx), width);
if (type.isError) {
return type;
}
type.type = wrap(integerType);
return type;
}
/// Populate the fhe python module.
void mlir::concretelang::python::populateDialectFHESubmodule(
pybind11::module &m) {
m.doc() = "FHE dialect Python native extension";
mlir_type_subclass(m, "EncryptedIntegerType", fheTypeIsAnEncryptedIntegerType)
mlir_type_subclass(m, "EncryptedIntegerType",
[](MlirType type) {
return unwrap(type).isa<EncryptedUnsignedIntegerType>();
})
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
unsigned width) {
MlirTypeOrError typeOrError =
fheEncryptedIntegerTypeGetChecked(ctx, width);
IntegerTypeGetChecked<EncryptedUnsignedIntegerType>(ctx, width);
if (typeOrError.isError) {
throw std::invalid_argument("can't create eint with the given width");
}
@@ -36,11 +66,13 @@ void mlir::concretelang::python::populateDialectFHESubmodule(
});
mlir_type_subclass(m, "EncryptedSignedIntegerType",
fheTypeIsAnEncryptedSignedIntegerType)
[](MlirType type) {
return unwrap(type).isa<EncryptedSignedIntegerType>();
})
.def_classmethod(
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
MlirTypeOrError typeOrError =
fheEncryptedSignedIntegerTypeGetChecked(ctx, width);
IntegerTypeGetChecked<EncryptedSignedIntegerType>(ctx, width);
if (typeOrError.isError) {
throw std::invalid_argument(
"can't create esint with the given width");

View File

@@ -25,13 +25,10 @@ from .compilation_feedback import CompilationFeedback
from .key_set import KeySet
from .public_result import PublicResult
from .public_arguments import PublicArguments
from .jit_compilation_result import JITCompilationResult
from .jit_lambda import JITLambda
from .lambda_argument import LambdaArgument
from .library_compilation_result import LibraryCompilationResult
from .library_lambda import LibraryLambda
from .client_support import ClientSupport
from .jit_support import JITSupport
from .library_support import LibrarySupport
from .evaluation_keys import EvaluationKeys
from .value import Value
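
With the JIT wrappers gone, the classes imported above are the whole client-side surface of the package. A hedged sketch of the intended key serialization round trip; `KeySet.deserialize` and `get_evaluation_keys` appear later in this diff, while `serialize()` is an assumed counterpart of the native `keySetSerialize` binding:

```python
# Hedged sketch, not part of the diff. KeySet.deserialize and
# get_evaluation_keys are visible later in this commit; serialize() is an
# assumed wrapper over the native keySetSerialize binding.
blob = key_set.serialize()
restored = KeySet.deserialize(blob)
evaluation_keys = restored.get_evaluation_keys()
```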

View File

@@ -44,8 +44,8 @@ class ClientSupport(WrapperCpp):
)
super().__init__(client_support)
@staticmethod
# pylint: disable=arguments-differ
@staticmethod
def new() -> "ClientSupport":
"""Build a ClientSupport.
@@ -176,7 +176,9 @@ class ClientSupport(WrapperCpp):
f"public_result must be of type PublicResult, not {type(public_result)}"
)
lambda_arg = LambdaArgument.wrap(
_ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
_ClientSupport.decrypt_result(
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
)
)
output_signs = client_parameters.output_signs()
@@ -216,7 +218,6 @@ class ClientSupport(WrapperCpp):
"""
# pylint: disable=too-many-return-statements,too-many-branches
if not isinstance(value, ACCEPTED_TYPES):
raise TypeError(
"value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}"

View File

@@ -39,8 +39,8 @@ class CompilationContext(WrapperCpp):
)
super().__init__(compilation_context)
@staticmethod
# pylint: disable=arguments-differ
@staticmethod
def new() -> "CompilationContext":
"""Build a CompilationContext.

View File

@@ -1,38 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
"""JITCompilationResult."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITCompilationResult as _JITCompilationResult,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class JITCompilationResult(WrapperCpp):
"""JITCompilationResult holds the result of a JIT compilation.
    It can be used with the JITSupport to load client parameters and execute the compiled
    code.
"""
def __init__(self, jit_compilation_result: _JITCompilationResult):
"""Wrap the native Cpp object.
Args:
jit_compilation_result (_JITCompilationResult): object to wrap
Raises:
TypeError: if jit_compilation_result is not of type _JITCompilationResult
"""
if not isinstance(jit_compilation_result, _JITCompilationResult):
raise TypeError(
f"jit_compilation_result must be of type _JITCompilationResult, not "
f"{type(jit_compilation_result)}"
)
super().__init__(jit_compilation_result)

View File

@@ -1,35 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
"""JITLambda."""
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITLambda as _JITLambda,
)
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
class JITLambda(WrapperCpp):
"""JITLambda contains an in-memory executable code and can be ran using JITSupport.
It's an artifact of JIT compilation, which stays in memory and can be executed with the help of
JITSupport.
"""
def __init__(self, jit_lambda: _JITLambda):
"""Wrap the native Cpp object.
Args:
jit_lambda (_JITLambda): object to wrap
Raises:
TypeError: if jit_lambda is not of type JITLambda
"""
if not isinstance(jit_lambda, _JITLambda):
raise TypeError(
f"jit_lambda must be of type _JITLambda, not {type(jit_lambda)}"
)
super().__init__(jit_lambda)

View File

@@ -1,202 +0,0 @@
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
"""JITSupport.
Just-in-time compilation provides a way to compile and execute an MLIR program while keeping the executable
code in memory.
"""
from typing import Optional
# pylint: disable=no-name-in-module,import-error
from mlir._mlir_libs._concretelang._compiler import (
JITSupport as _JITSupport,
)
# pylint: enable=no-name-in-module,import-error
from .utils import lookup_runtime_lib
from .compilation_options import CompilationOptions
from .jit_compilation_result import JITCompilationResult
from .client_parameters import ClientParameters
from .compilation_feedback import CompilationFeedback
from .jit_lambda import JITLambda
from .public_arguments import PublicArguments
from .public_result import PublicResult
from .wrapper import WrapperCpp
from .evaluation_keys import EvaluationKeys
class JITSupport(WrapperCpp):
"""Support class for JIT compilation and execution."""
def __init__(self, jit_support: _JITSupport):
"""Wrap the native Cpp object.
Args:
jit_support (_JITSupport): object to wrap
Raises:
TypeError: if jit_support is not of type _JITSupport
"""
if not isinstance(jit_support, _JITSupport):
raise TypeError(
f"jit_support must be of type _JITSupport, not {type(jit_support)}"
)
super().__init__(jit_support)
@staticmethod
# pylint: disable=arguments-differ
def new(runtime_lib_path: Optional[str] = None) -> "JITSupport":
"""Build a JITSupport.
Args:
runtime_lib_path (Optional[str]): path to the runtime library. Defaults to None.
Raises:
TypeError: if runtime_lib_path is not of type str or None
Returns:
JITSupport
"""
if runtime_lib_path is None:
runtime_lib_path = lookup_runtime_lib()
else:
if not isinstance(runtime_lib_path, str):
raise TypeError(
f"runtime_lib_path must be of type str, not {type(runtime_lib_path)}"
)
return JITSupport.wrap(_JITSupport(runtime_lib_path))
# pylint: enable=arguments-differ
def compile(
self,
mlir_program: str,
options: CompilationOptions = CompilationOptions.new("main"),
) -> JITCompilationResult:
"""JIT compile an MLIR program using Concrete dialects.
Args:
mlir_program (str): textual representation of the mlir program to compile
options (CompilationOptions): compilation options
Raises:
TypeError: if mlir_program is not of type str
TypeError: if options is not of type CompilationOptions
Returns:
JITCompilationResult: the result of the JIT compilation
"""
if not isinstance(mlir_program, str):
raise TypeError(
f"mlir_program must be of type str, not {type(mlir_program)}"
)
if not isinstance(options, CompilationOptions):
raise TypeError(
f"options must be of type CompilationOptions, not {type(options)}"
)
return JITCompilationResult.wrap(
self.cpp().compile(mlir_program, options.cpp())
)
def load_client_parameters(
self, compilation_result: JITCompilationResult
) -> ClientParameters:
"""Load the client parameters from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
ClientParameters: appropriate client parameters for the compiled program
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
)
return ClientParameters.wrap(
self.cpp().load_client_parameters(compilation_result.cpp())
)
def load_compilation_feedback(
self, compilation_result: JITCompilationResult
) -> CompilationFeedback:
"""Load the compilation feedback from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
CompilationFeedback: the compilation feedback for the compiled program
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
)
return CompilationFeedback.wrap(
self.cpp().load_compilation_feedback(compilation_result.cpp())
)
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
"""Load the JITLambda from the JIT compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation.
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
Returns:
JITLambda: loaded JITLambda to be executed
"""
if not isinstance(compilation_result, JITCompilationResult):
raise TypeError(
f"compilation_result must be a JITCompilationResult not {type(compilation_result)}"
)
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
def server_call(
self,
jit_lambda: JITLambda,
public_arguments: PublicArguments,
evaluation_keys: EvaluationKeys,
) -> PublicResult:
"""Call the JITLambda with public_arguments.
Args:
jit_lambda (JITLambda): A server lambda to call.
public_arguments (PublicArguments): The arguments of the call.
            evaluation_keys (EvaluationKeys): Evaluation keys of the call.
Raises:
TypeError: if jit_lambda is not of type JITLambda
TypeError: if public_arguments is not of type PublicArguments
TypeError: if evaluation_keys is not of type EvaluationKeys
Returns:
PublicResult: the result of the call of the server lambda.
"""
if not isinstance(jit_lambda, JITLambda):
raise TypeError(
f"jit_lambda must be of type JITLambda, not {type(jit_lambda)}"
)
if not isinstance(public_arguments, PublicArguments):
raise TypeError(
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
)
if not isinstance(evaluation_keys, EvaluationKeys):
raise TypeError(
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
)
return PublicResult.wrap(
self.cpp().server_call(
jit_lambda.cpp(), public_arguments.cpp(), evaluation_keys.cpp()
)
)

View File

@@ -15,7 +15,6 @@ from mlir._mlir_libs._concretelang._compiler import (
# pylint: enable=no-name-in-module,import-error
from .wrapper import WrapperCpp
from .evaluation_keys import EvaluationKeys
from .client_parameters import ClientParameters
class KeySet(WrapperCpp):
@@ -64,10 +63,6 @@ class KeySet(WrapperCpp):
)
return KeySet.wrap(_KeySet.deserialize(serialized_key_set))
def client_parameters(self) -> ClientParameters:
"""Get client parameters of keyset."""
return self.cpp().client_parameters()
def get_evaluation_keys(self) -> EvaluationKeys:
"""
Get evaluation keys for execution.

View File

@@ -124,6 +124,8 @@ class LibrarySupport(WrapperCpp):
generateCppHeader,
)
)
if not os.path.isdir(output_path):
os.makedirs(output_path)
library_support.output_dir_path = output_path
return library_support
@@ -216,27 +218,29 @@ class LibrarySupport(WrapperCpp):
def load_compilation_feedback(
self, compilation_result: LibraryCompilationResult
) -> CompilationFeedback:
"""Load the compilation feedback from the JIT compilation result.
"""Load the compilation feedback from the compilation result.
Args:
compilation_result (JITCompilationResult): result of the JIT compilation
compilation_result (LibraryCompilationResult): result of the compilation
Raises:
TypeError: if compilation_result is not of type JITCompilationResult
TypeError: if compilation_result is not of type LibraryCompilationResult
Returns:
CompilationFeedback: the compilation feedback for the compiled program
"""
if not isinstance(compilation_result, LibraryCompilationResult):
raise TypeError(
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}"
)
return CompilationFeedback.wrap(
self.cpp().load_compilation_feedback(compilation_result.cpp())
)
def load_server_lambda(
self, library_compilation_result: LibraryCompilationResult
self,
library_compilation_result: LibraryCompilationResult,
simulation: bool,
) -> LibraryLambda:
"""Load the server lambda from the library compilation result.
@@ -255,7 +259,7 @@ class LibrarySupport(WrapperCpp):
f"{type(library_compilation_result)}"
)
return LibraryLambda.wrap(
self.cpp().load_server_lambda(library_compilation_result.cpp())
self.cpp().load_server_lambda(library_compilation_result.cpp(), simulation)
)
def server_call(
@@ -340,10 +344,10 @@ class LibrarySupport(WrapperCpp):
"""
return self.cpp().get_shared_lib_path()
def get_client_parameters_path(self) -> str:
"""Get the path where the client parameters file is expected to be.
def get_program_info_path(self) -> str:
"""Get the path where the program info file is expected to be.
Returns:
str: path to the client parameters file
str: path to the program info file
"""
return self.cpp().get_client_parameters_path()
return self.cpp().get_program_info_path()
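
Putting the `LibrarySupport` changes together: `load_server_lambda` now takes an explicit simulation flag, and the client-parameters path is replaced by the program-info path. A hedged sketch of driving it after this commit; `new` and `compile` are assumed to keep their existing signatures, and the other calls appear in the diff above:

```python
# Hedged sketch, not part of the diff. LibrarySupport.new and compile are
# assumed unchanged by this commit; the remaining calls appear above.
mlir_program = "..."  # textual MLIR, elided
support = LibrarySupport.new("./out")  # hypothetical output directory
result = support.compile(mlir_program)
feedback = support.load_compilation_feedback(result)
server_lambda = support.load_server_lambda(result, simulation=False)
print(support.get_shared_lib_path())
print(support.get_program_info_path())  # replaces get_client_parameters_path
```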

View File

@@ -2,7 +2,7 @@
# pylint: disable=no-name-in-module,import-error
from typing import List, Union
from typing import Union
import numpy as np
from mlir._mlir_libs._concretelang._compiler import (
@@ -65,47 +65,16 @@ class SimulatedValueDecrypter(WrapperCpp):
decrypted value
"""
shape = tuple(self.cpp().get_shape(position))
lambda_arg = self.cpp().decrypt(position, value.cpp())
is_signed = lambda_arg.is_signed()
if lambda_arg.is_scalar():
return (
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
)
if len(shape) == 0:
return self.decrypt_scalar(position, value)
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
shape
shape = lambda_arg.get_tensor_shape()
return (
np.array(lambda_arg.get_signed_tensor_data()).reshape(shape)
if is_signed
else np.array(lambda_arg.get_tensor_data()).reshape(shape)
)
def decrypt_scalar(self, position: int, value: Value) -> int:
"""
Decrypt scalar.
Args:
position (int):
position of the argument within the circuit
value (Value):
scalar value to decrypt
Returns:
int:
decrypted scalar
"""
return self.cpp().decrypt_scalar(position, value.cpp())
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
"""
Decrypt tensor.
Args:
position (int):
position of the argument within the circuit
value (Value):
tensor value to decrypt
Returns:
List[int]:
decrypted tensor
"""
return self.cpp().decrypt_tensor(position, value.cpp())

View File

@@ -2,7 +2,7 @@
# pylint: disable=no-name-in-module,import-error
from typing import List, Union
from typing import Union
import numpy as np
from mlir._mlir_libs._concretelang._compiler import (
@@ -66,47 +66,16 @@ class ValueDecrypter(WrapperCpp):
decrypted value
"""
shape = tuple(self.cpp().get_shape(position))
lambda_arg = self.cpp().decrypt(position, value.cpp())
is_signed = lambda_arg.is_signed()
if lambda_arg.is_scalar():
return (
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
)
if len(shape) == 0:
return self.decrypt_scalar(position, value)
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
shape
shape = lambda_arg.get_tensor_shape()
return (
np.array(lambda_arg.get_signed_tensor_data()).reshape(shape)
if is_signed
else np.array(lambda_arg.get_tensor_data()).reshape(shape)
)
def decrypt_scalar(self, position: int, value: Value) -> int:
"""
Decrypt scalar.
Args:
position (int):
position of the argument within the circuit
value (Value):
scalar value to decrypt
Returns:
int:
decrypted scalar
"""
return self.cpp().decrypt_scalar(position, value.cpp())
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
"""
Decrypt tensor.
Args:
position (int):
position of the argument within the circuit
value (Value):
tensor value to decrypt
Returns:
List[int]:
decrypted tensor
"""
return self.cpp().decrypt_tensor(position, value.cpp())
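
The dedicated `decrypt_scalar`/`decrypt_tensor` helpers are dropped; `decrypt` inspects the native `LambdaArgument` it gets back and returns either a Python scalar or a numpy array. A hedged sketch of a call after this change, assuming `decrypter` and `value` come from the surrounding client workflow:

```python
# Hedged sketch, not part of the diff: the single decrypt() entry point.
# `decrypter` is a ValueDecrypter and `value` a Value, both assumed to come
# from the surrounding client workflow.
result = decrypter.decrypt(position=0, value=value)
if isinstance(result, int):
    print("scalar result:", result)
else:
    print("tensor result with shape:", result.shape)
```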

View File

@@ -1,10 +0,0 @@
[package]
name = "concrete-compiler"
version = "0.1.0"
edition = "2021"
[build-dependencies]
bindgen = "0.60.1"
[dev-dependencies]
tempdir = "0.3.7"

View File

@@ -1,12 +0,0 @@
# Rust Bindings
A Rust library providing an API to the Concrete Compiler.
### Build
Set `CONCRETE_COMPILER_INSTALL_DIR` to the compiler installation path before building with `cargo`:
```bash
$ export CONCRETE_COMPILER_INSTALL_DIR=/installation/path/concretecompiler/
$ cargo build --release
```

View File

@@ -1,35 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang-c/Dialect/FHE.h>
#include <concretelang-c/Dialect/FHELinalg.h>
#include <concretelang-c/Support/CompilerEngine.h>
#include <mlir-c/AffineExpr.h>
#include <mlir-c/AffineMap.h>
#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Conversion.h>
#include <mlir-c/Debug.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/Async.h>
#include <mlir-c/Dialect/ControlFlow.h>
#include <mlir-c/Dialect/Func.h>
#include <mlir-c/Dialect/GPU.h>
#include <mlir-c/Dialect/LLVM.h>
#include <mlir-c/Dialect/Linalg.h>
#include <mlir-c/Dialect/PDL.h>
#include <mlir-c/Dialect/Quant.h>
#include <mlir-c/Dialect/SCF.h>
#include <mlir-c/Dialect/Shape.h>
#include <mlir-c/Dialect/SparseTensor.h>
#include <mlir-c/Dialect/Tensor.h>
#include <mlir-c/ExecutionEngine.h>
#include <mlir-c/IR.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Interfaces.h>
#include <mlir-c/Pass.h>
#include <mlir-c/RegisterEverything.h>
#include <mlir-c/Support.h>
#include <mlir-c/Transforms.h>

View File

@@ -1,378 +0,0 @@
extern crate bindgen;
use std::env;
use std::error::Error;
use std::path::Path;
use std::process::exit;
const MLIR_STATIC_LIBS: &[&str] = &[
"MLIRArithAttrToLLVMConversion",
"MLIRDestinationStyleOpInterface",
"MLIRVectorTransformOps",
"MLIRMemRefTransformOps",
"MLIRGPUTransformOps",
"MLIRAffineTransformOps",
"MLIRBytecodeReader",
"MLIRAsmParser",
"MLIRIndexDialect",
"MLIRMaskableOpInterface",
"MLIRMaskingOpInterface",
"MLIRInferIntRangeCommon",
"MLIRShapedOpInterfaces",
"MLIRTransformDialectUtils",
"MLIRParallelCombiningOpInterface",
"MLIRMemRefDialect",
"MLIRVectorToSPIRV",
"MLIRControlFlowInterfaces",
"MLIRLinalgToStandard",
"MLIRAnalysis",
"MLIRSPIRVDeserialization",
"MLIRTransformDialect",
"MLIRSparseTensorPipelines",
"MLIRVectorToGPU",
"MLIRTranslateLib",
"MLIRPass",
"MLIRComplexToLibm",
"MLIRInferTypeOpInterface",
"MLIRMemRefToSPIRV",
"MLIRAMDGPUToROCDL",
"MLIRBufferizationTransformOps",
"MLIRExecutionEngineUtils",
"MLIRNVVMDialect",
"MLIRSCFUtils",
"MLIRLinalgTransforms",
"MLIRParser",
"MLIRFuncTransforms",
"MLIRTosaTestPasses",
"MLIRTosaToArith",
"MLIRTensorDialect",
"MLIRGPUTransforms",
"MLIRLowerableDialectsToLLVM",
"MLIRBufferizationToMemRef",
"MLIRPresburger",
"MLIRFuncDialect",
"MLIRPDLToPDLInterp",
"MLIRArithTransforms",
"MLIRViewLikeInterface",
"MLIRTargetCpp",
"MLIROpenMPToLLVM",
"MLIRSPIRVConversion",
"MLIRNVGPUTransforms",
"MLIRSparseTensorTransforms",
"MLIRAffineAnalysis",
"MLIRArmSVETransforms",
"MLIRArmNeon2dToIntr",
"MLIRDataLayoutInterfaces",
"MLIRAffineTransforms",
"MLIROpenACCToLLVMIRTranslation",
"MLIRTensorUtils",
"MLIRSPIRVSerialization",
"MLIRShapeToStandard",
"MLIRArithToSPIRV",
"MLIRArithDialect",
"MLIRFuncToSPIRV",
"MLIRQuantUtils",
"MLIRTensorTilingInterfaceImpl",
"MLIRX86VectorToLLVMIRTranslation",
"MLIRCopyOpInterface",
"MLIRMathToLibm",
"MLIRGPUToGPURuntimeTransforms",
"MLIRLLVMDialect",
"MLIRAffineDialect",
"MLIRTransforms",
"MLIRVectorTransforms",
"MLIROpenMPDialect",
"MLIRControlFlowDialect",
"MLIRVectorUtils",
"MLIRROCDLDialect",
"MLIRPDLDialect",
"MLIRAsyncDialect",
"MLIRLinalgToLLVM",
"MLIROpenACCDialect",
"MLIRVectorDialect",
"MLIROpenACCToSCF",
"MLIRIR",
"MLIRCAPIIR",
"MLIRTargetLLVMIRImport",
"MLIRTensorToLinalg",
"MLIRCallInterfaces",
"MLIRTensorInferTypeOpInterfaceImpl",
"MLIRTransformDialectTransforms",
"MLIRComplexDialect",
"MLIRAffineUtils",
"MLIRLoopLikeInterface",
"MLIRDialect",
"MLIRLinalgUtils",
"MLIRSCFToSPIRV",
"MLIRAffineToStandard",
"MLIRX86VectorDialect",
"MLIRGPUToVulkanTransforms",
"MLIRRewrite",
"MLIRAMXToLLVMIRTranslation",
"MLIRInferIntRangeInterface",
"MLIRCAPIRegisterEverything",
"MLIRNVVMToLLVMIRTranslation",
"MLIRAsyncTransforms",
"MLIRPDLInterpDialect",
"MLIRTransformUtils",
"MLIRLinalgDialect",
"MLIRMathDialect",
"MLIRMemRefTransforms",
"MLIRSPIRVModuleCombiner",
"MLIRMathToLLVM",
"MLIRControlFlowToLLVM",
"MLIRArmSVEDialect",
"MLIRSPIRVTranslateRegistration",
"MLIRToLLVMIRTranslationRegistration",
"MLIRSCFDialect",
"MLIRTilingInterface",
"MLIREmitCDialect",
"MLIRTableGen",
"MLIRTosaToSCF",
"MLIROpenMPToLLVMIRTranslation",
"MLIRSupport",
"MLIROpenACCToLLVM",
"MLIRAMDGPUDialect",
"MLIRTosaToLinalg",
"MLIRSparseTensorUtils",
"MLIRFuncToLLVM",
"MLIRTargetLLVMIRExport",
"MLIRControlFlowToSPIRV",
"MLIRReconcileUnrealizedCasts",
"MLIRComplexToStandard",
"MLIRMathTransforms",
"MLIRSPIRVUtils",
"MLIRCastInterfaces",
"MLIRTosaToTensor",
"MLIRGPUToSPIRV",
"MLIRBufferizationDialect",
"MLIRSCFToControlFlow",
"MLIRArmSVEToLLVMIRTranslation",
"MLIRExecutionEngine",
"MLIRBufferizationTransforms",
"MLIRSparseTensorDialect",
"MLIRTensorToSPIRV",
"MLIRVectorToSCF",
"MLIRLLVMToLLVMIRTranslation",
"MLIRNVGPUDialect",
"MLIRAsyncToLLVM",
"MLIRAMXDialect",
"MLIRLinalgTransformOps",
"MLIRMathToSPIRV",
"MLIRSCFToOpenMP",
"MLIRShapeDialect",
"MLIRGPUToROCDLTransforms",
"MLIRGPUToNVVMTransforms",
"MLIRTensorTransforms",
"MLIRSCFToGPU",
"MLIRDialectUtils",
"MLIRNVGPUToNVVM",
"MLIRTosaDialect",
"MLIRVectorToLLVM",
"MLIRSPIRVDialect",
"MLIRSideEffectInterfaces",
"MLIRQuantDialect",
"MLIRSCFTransforms",
"MLIRMLProgramDialect",
"MLIRDLTIDialect",
"MLIRLinalgFrontend",
"MLIRROCDLToLLVMIRTranslation",
"MLIRArmNeonDialect",
"MLIRSPIRVToLLVM",
"MLIRLLVMIRTransforms",
"MLIRTosaTransforms",
"MLIRLLVMCommonConversion",
"MLIRSCFTransformOps",
"MLIRArmNeonToLLVMIRTranslation",
"MLIRAMXTransforms",
"MLIRSPIRVTransforms",
"MLIRMemRefToLLVM",
"MLIRSPIRVBinaryUtils",
"MLIRArithUtils",
"MLIRVectorInterfaces",
"MLIRGPUOps",
"MLIRComplexToLLVM",
"MLIRShapeOpsTransforms",
"MLIRX86VectorTransforms",
"MLIRArithToLLVM",
];
const LLVM_STATIC_LIBS: &[&str] = &[
"LLVMAggressiveInstCombine",
"LLVMAnalysis",
"LLVMAsmParser",
"LLVMAsmPrinter",
"LLVMBinaryFormat",
"LLVMBitReader",
"LLVMBitstreamReader",
"LLVMBitWriter",
"LLVMCFGuard",
"LLVMCodeGen",
"LLVMCore",
"LLVMCoroutines",
"LLVMDebugInfoCodeView",
"LLVMDebugInfoDWARF",
"LLVMDebugInfoMSF",
"LLVMDebugInfoPDB",
"LLVMDemangle",
"LLVMExecutionEngine",
"LLVMFrontendOpenMP",
"LLVMGlobalISel",
"LLVMInstCombine",
"LLVMInstrumentation",
"LLVMipo",
"LLVMIRReader",
"LLVMJITLink",
"LLVMLinker",
"LLVMMC",
"LLVMMCDisassembler",
"LLVMMCParser",
"LLVMObjCARCOpts",
"LLVMObject",
"LLVMOption",
"LLVMOrcJIT",
"LLVMOrcShared",
"LLVMOrcTargetProcess",
"LLVMPasses",
"LLVMProfileData",
"LLVMRemarks",
"LLVMRuntimeDyld",
"LLVMScalarOpts",
"LLVMSelectionDAG",
"LLVMSupport",
"LLVMSymbolize",
"LLVMTableGen",
"LLVMTableGenGlobalISel",
"LLVMTarget",
"LLVMTargetParser",
"LLVMTextAPI",
"LLVMTransformUtils",
"LLVMVectorize",
];
#[cfg(target_arch = "aarch64")]
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[
"LLVMAArch64Utils",
"LLVMAArch64Info",
"LLVMAArch64Desc",
"LLVMAArch64CodeGen",
];
#[cfg(target_arch = "x86_64")]
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"];
const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
"AnalysisUtils",
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",
"ConcreteToCAPI",
"ConcretelangConversion",
"ConcretelangTransforms",
"FHETensorOpsToLinalg",
"ConcretelangServerLib",
"CONCRETELANGCAPIFHE",
"TFHEGlobalParametrization",
"ConcretelangClientLib",
"ConcretelangConcreteTransforms",
"ConcretelangSDFGInterfaces",
"ConcretelangSDFGTransforms",
"CONCRETELANGCAPISupport",
"FHELinalgDialect",
"TracingDialect",
"TracingDialectTransforms",
"TracingToCAPI",
"ConcretelangInterfaces",
"TFHEDialect",
"SimulateTFHE",
"CONCRETELANGCAPIFHELINALG",
"FHELinalgDialectTransforms",
"FHEDialect",
"FHEInterfaces",
"FHEDialectTransforms",
"TFHEToConcrete",
"FHEToTFHECrt",
"FHEToTFHEScalar",
"TFHEDialectTransforms",
"TFHEKeyNormalization",
"concrete_optimizer",
"LinalgExtras",
"FHEDialectAnalysis",
"ConcreteDialect",
"RTDialectAnalysis",
"SDFGDialect",
"ExtractSDFGOps",
"SDFGToStreamEmulator",
"TFHEDialectAnalysis",
"ConcreteDialectAnalysis",
];
fn main() {
if let Err(error) = run() {
eprintln!("{}", error);
exit(1);
}
}
fn run() -> Result<(), Box<dyn Error>> {
let mut include_paths = Vec::new();
// if set, use the installation path of the concrete compiler to look up libraries and include files
match env::var("CONCRETE_COMPILER_INSTALL_DIR") {
Ok(install_dir) => {
println!("cargo:rustc-link-search={}/lib/", install_dir);
include_paths.push(Path::new(&format!("{}/include/", install_dir)).to_path_buf());
}
Err(_e) => println!(
"cargo:warning=You are not setting CONCRETE_COMPILER_INSTALL_DIR, \
so your compiler/linker will have to look up libs and include dirs on their own"
),
}
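// typical invocation (the path shown is illustrative):
//   CONCRETE_COMPILER_INSTALL_DIR=/path/to/concrete-compiler/install cargo build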
// linking to static libs
let all_static_libs = CONCRETE_COMPILER_STATIC_LIBS
.into_iter()
.chain(MLIR_STATIC_LIBS)
.chain(LLVM_STATIC_LIBS)
.chain(LLVM_TARGET_SPECIFIC_STATIC_LIBS);
for static_lib_name in all_static_libs {
println!("cargo:rustc-link-lib=static={}", static_lib_name);
}
// concrete compiler runtime
println!("cargo:rustc-link-lib=ConcretelangRuntime");
// concrete optimizer
// `-bundle` serves to avoid multiple definition issues
println!("cargo:rustc-link-lib=static:-bundle=concrete_optimizer_cpp");
// required by llvm
println!("cargo:rustc-link-lib=ncurses");
if let Some(name) = get_system_libcpp() {
println!("cargo:rustc-link-lib={}", name);
}
// zlib
println!("cargo:rustc-link-lib=z");
println!("cargo:rerun-if-changed=api.h");
bindgen::builder()
.header("api.h")
.clang_args(
include_paths
.into_iter()
.map(|path| format!("-I{}", path.to_str().unwrap())),
)
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.generate()
.unwrap()
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;
Ok(())
}
fn get_system_libcpp() -> Option<&'static str> {
if cfg!(target_env = "msvc") {
None
} else if cfg!(target_os = "macos") {
Some("c++")
} else {
Some("stdc++")
}
}

View File

@@ -1,768 +0,0 @@
//! FHE dialect module
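//!
//! Helpers to build operations of the `FHE` dialect through the MLIR C API bindings.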
use crate::mlir::ffi::*;
use crate::mlir::*;
pub fn create_fhe_add_eint_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.add_eint",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_add_eint_int_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.add_eint_int",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_sub_eint_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.sub_eint",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_sub_eint_int_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.sub_eint_int",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_sub_int_eint_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(rhs)];
// infer result type from operands
create_op(
context,
"FHE.sub_int_eint",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_negate_eint_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(eint)];
// infer result type from operands
create_op(
context,
"FHE.neg_eint",
&[eint],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_mul_eint_int_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.mul_eint_int",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_apply_lut_op(
context: MlirContext,
eint: MlirValue,
lut: MlirValue,
result_type: MlirType,
) -> MlirOperation {
create_op(
context,
"FHE.apply_lookup_table",
&[eint, lut],
[result_type].as_slice(),
&[],
false,
)
}
#[derive(Debug)]
pub enum FHEError {
InvalidFHEType,
InvalidWidth,
}
pub fn convert_eint_to_esint_type(
context: MlirContext,
eint_type: MlirType,
) -> Result<MlirType, FHEError> {
unsafe {
let width = fheTypeIntegerWidthGet(eint_type);
if width == 0 {
return Err(FHEError::InvalidFHEType);
}
let type_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width);
if type_or_error.isError {
Err(FHEError::InvalidWidth)
} else {
Ok(type_or_error.type_)
}
}
}
pub fn convert_esint_to_eint_type(
context: MlirContext,
esint_type: MlirType,
) -> Result<MlirType, FHEError> {
unsafe {
let width = fheTypeIntegerWidthGet(esint_type);
if width == 0 {
return Err(FHEError::InvalidFHEType);
}
let type_or_error = fheEncryptedIntegerTypeGetChecked(context, width);
if type_or_error.isError {
Err(FHEError::InvalidWidth)
} else {
Ok(type_or_error.type_)
}
}
}
pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
unsafe {
let results = [convert_eint_to_esint_type(context, mlirValueGetType(eint)).unwrap()];
// infer result type from operands
create_op(
context,
"FHE.to_signed",
&[eint],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation {
unsafe {
let results = [convert_esint_to_eint_type(context, mlirValueGetType(esint)).unwrap()];
// infer result type from operands
create_op(
context,
"FHE.to_unsigned",
&[esint],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_fhe_zero_eint_op(context: MlirContext, result_type: MlirType) -> MlirOperation {
create_op(
context,
"FHE.zero",
&[],
[result_type].as_slice(),
&[],
false,
)
}
pub fn create_fhe_zero_eint_tensor_op(
context: MlirContext,
result_type: MlirType,
) -> MlirOperation {
create_op(
context,
"FHE.zero_tensor",
&[],
[result_type].as_slice(),
&[],
false,
)
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_invalid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0);
assert!(invalid_eint.isError);
}
}
#[test]
fn test_valid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
assert!(fheTypeIsAnEncryptedIntegerType(eint));
assert!(!fheTypeIsAnEncryptedSignedIntegerType(eint));
assert_eq!(fheTypeIntegerWidthGet(eint), 5);
let printed_eint = super::print_mlir_type_to_string(eint);
let expected_eint = "!FHE.eint<5>";
assert_eq!(printed_eint, expected_eint);
}
}
#[test]
fn test_valid_fhe_esint_type() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
assert!(!esint_or_error.isError);
let esint = esint_or_error.type_;
assert!(fheTypeIsAnEncryptedSignedIntegerType(esint));
assert!(!fheTypeIsAnEncryptedIntegerType(esint));
assert_eq!(fheTypeIntegerWidthGet(esint), 5);
let printed_esint = super::print_mlir_type_to_string(esint);
let expected_esint = "!FHE.esint<5>";
assert_eq!(printed_esint, expected_esint);
}
}
#[test]
fn test_fhe_func() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 5-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
// set input/output types of the FHE circuit
let func_input_types = [eint, eint];
let func_output_types = [eint];
// create the func operation
let func_op = create_func_with_block(
context,
"main",
func_input_types.as_slice(),
func_output_types.as_slice(),
);
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
let func_args = [
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
];
// create an FHE add_eint op and append it to the function block
let add_eint_op = create_fhe_add_eint_op(context, func_args[0], func_args[1]);
mlirBlockAppendOwnedOperation(func_block, add_eint_op);
// create ret operation and append it to the block
let ret_op = create_ret_op(context, mlirOperationGetResult(add_eint_op, 0));
mlirBlockAppendOwnedOperation(func_block, ret_op);
// create module to hold the previously created function
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
let printed_module =
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let expected_module = "\
module {
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}
}
";
assert_eq!(printed_module, expected_module);
}
}
#[test]
fn test_zero_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
let printed_op = print_mlir_operation_to_string(zero_op);
let expected_op = "%0 = \"FHE.zero\"() : () -> !FHE.eint<6>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_zero_tensor_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 4-bit eint tensor type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
let shape: [i64; 3] = [60, 66, 73];
let location = mlirLocationUnknownGet(context);
let eint_tensor = mlirRankedTensorTypeGetChecked(
location,
3,
shape.as_ptr(),
eint,
mlirAttributeGetNull(),
);
let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor);
let printed_op = print_mlir_operation_to_string(zero_op);
let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_add_eint_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// add eint with itself
let add_eint_op = create_fhe_add_eint_op(context, eint_value, eint_value);
mlirBlockAppendOwnedOperation(main_block, add_eint_op);
let printed_op = print_mlir_operation_to_string(add_eint_op);
let expected_op =
"%1 = \"FHE.add_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_add_eint_int_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// add eint int
let add_eint_int_op = create_fhe_add_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, add_eint_int_op);
let printed_op = print_mlir_operation_to_string(add_eint_int_op);
let expected_op =
"%1 = \"FHE.add_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_eint_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// sub eint with itself
let sub_eint_op = create_fhe_sub_eint_op(context, eint_value, eint_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_op);
let printed_op = print_mlir_operation_to_string(sub_eint_op);
let expected_op =
"%1 = \"FHE.sub_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_eint_int_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// sub eint int
let sub_eint_int_op = create_fhe_sub_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
let expected_op =
"%1 = \"FHE.sub_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_int_eint_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// sub int eint
let sub_eint_int_op = create_fhe_sub_int_eint_op(context, int_value, eint_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
let expected_op =
"%1 = \"FHE.sub_int_eint\"(%c-55_i7, %0) : (i7, !FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_negate_eint_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// negate eint
let neg_eint_op = create_fhe_negate_eint_op(context, eint_value);
mlirBlockAppendOwnedOperation(main_block, neg_eint_op);
let printed_op = print_mlir_operation_to_string(neg_eint_op);
let expected_op = "%1 = \"FHE.neg_eint\"(%0) : (!FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_mul_eint_int_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// mul eint int
let mul_eint_int_op = create_fhe_mul_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op);
let printed_op = print_mlir_operation_to_string(mul_eint_int_op);
let expected_op =
"%1 = \"FHE.mul_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_to_signed_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// to signed
let to_signed_op = create_fhe_to_signed_op(context, eint_value);
mlirBlockAppendOwnedOperation(main_block, to_signed_op);
let printed_op = print_mlir_operation_to_string(to_signed_op);
let expected_op = "%1 = \"FHE.to_signed\"(%0) : (!FHE.eint<6>) -> !FHE.esint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_to_unsigned_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit esint type
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 6);
assert!(!esint_or_error.isError);
let esint6_type = esint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, esint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let esint_value = mlirOperationGetResult(zero_op, 0);
// to unsigned
let to_unsigned_op = create_fhe_to_unsigned_op(context, esint_value);
mlirBlockAppendOwnedOperation(main_block, to_unsigned_op);
let printed_op = print_mlir_operation_to_string(to_unsigned_op);
let expected_op = "%1 = \"FHE.to_unsigned\"(%0) : (!FHE.esint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_apply_lut_op() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an lut
let table: [i64; 64] = [0; 64];
let constant_lut_op = create_constant_flat_tensor_op(context, &table, 64);
mlirBlockAppendOwnedOperation(main_block, constant_lut_op);
let lut = mlirOperationGetResult(constant_lut_op, 0);
// LUT op
let apply_lut_op = create_fhe_apply_lut_op(context, eint_value, lut, eint6_type);
mlirBlockAppendOwnedOperation(main_block, apply_lut_op);
let printed_op = print_mlir_operation_to_string(apply_lut_op);
let expected_op = "%1 = \"FHE.apply_lookup_table\"(%0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
}

View File

@@ -1,4 +0,0 @@
pub mod compiler;
pub mod fhe;
pub mod fhelinalg;
pub mod mlir;

View File

@@ -1,443 +0,0 @@
//! MLIR module
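//!
//! Wrappers and helpers around the bindgen-generated MLIR C API bindings.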
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
pub mod ffi {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}
use ffi::*;
use std::ffi::CString;
use std::ops::AddAssign;
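/// Callback for the MLIR printing APIs: appends the printed chunk referenced by
/// `mlirStrRef` to the Rust `String` passed through `user_data`.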
pub(crate) unsafe extern "C" fn mlir_rust_string_receiver_callback(
mlirStrRef: MlirStringRef,
user_data: *mut ::std::os::raw::c_void,
) {
let rust_string = &mut *(user_data as *mut String);
let slc = std::slice::from_raw_parts(mlirStrRef.data as *const u8, mlirStrRef.length as usize);
rust_string.add_assign(&String::from_utf8_lossy(slc));
}
pub fn print_mlir_operation_to_string(op: MlirOperation) -> String {
let mut rust_string = String::default();
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
unsafe {
mlirOperationPrint(op, Some(mlir_rust_string_receiver_callback), receiver_ptr);
}
rust_string
}
pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String {
let mut rust_string = String::default();
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
unsafe {
mlirTypePrint(
mlir_type,
Some(mlir_rust_string_receiver_callback),
receiver_ptr,
);
}
rust_string
}
/// Returns a function operation with a region that contains a block.
///
/// The function is defined using the provided input and output types. The main block of the
/// function can later be fetched; from it we can get the function arguments, and it is where
/// we append operations.
///
/// # Examples
/// ```
/// use concrete_compiler::mlir::*;
/// use concrete_compiler::mlir::ffi::*;
/// unsafe{
/// let context = mlirContextCreate();
/// register_all_dialects(context);
///
/// // input/output types
/// let func_input_types = [
/// mlirIntegerTypeGet(context, 64),
/// mlirIntegerTypeGet(context, 64),
/// ];
/// let func_output_types = [mlirIntegerTypeGet(context, 64)];
///
/// let func_op = create_func_with_block(
/// context,
/// "test",
/// func_input_types.as_slice(),
/// func_output_types.as_slice(),
/// );
///
/// // we can fetch the main block of the function from the function region
/// let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
/// // we can get arguments to later be used as operands to other operations
/// let func_args = [
/// mlirBlockGetArgument(func_block, 0),
/// mlirBlockGetArgument(func_block, 1),
/// ];
/// // to add an operation to the function, we will append it to the main block
/// let addi_op = create_addi_op(context, func_args[0], func_args[1]);
/// mlirBlockAppendOwnedOperation(func_block, addi_op);
/// }
/// ```
///
pub fn create_func_with_block(
context: MlirContext,
func_name: &str,
func_input_types: &[MlirType],
func_output_types: &[MlirType],
) -> MlirOperation {
unsafe {
// create the main block of the function
let locations = (0..func_input_types.len())
.into_iter()
.map(|_| mlirLocationUnknownGet(context))
.collect::<Vec<_>>();
let func_block = mlirBlockCreate(
func_input_types.len().try_into().unwrap(),
func_input_types.as_ptr(),
locations.as_ptr(),
);
// create region to hold the previously created block
let func_region = mlirRegionCreate();
mlirRegionAppendOwnedBlock(func_region, func_block);
// create function to hold the previously created region
let location = mlirLocationUnknownGet(context);
let func_str = CString::new("func.func").unwrap();
let mut func_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
// set function attributes
let func_type_str = CString::new("function_type").unwrap();
let sym_name_str = CString::new("sym_name").unwrap();
let func_name_str = CString::new(func_name).unwrap();
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
context,
func_input_types.len().try_into().unwrap(),
func_input_types.as_ptr(),
func_output_types.len().try_into().unwrap(),
func_output_types.as_ptr(),
));
let sym_name_attr = mlirStringAttrGet(
context,
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
);
mlirOperationStateAddAttributes(
&mut func_op_state,
2,
[
// func type
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
),
func_type_attr,
),
// func name
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
),
sym_name_attr,
),
]
.as_ptr(),
);
let func_op = mlirOperationCreate(&mut func_op_state);
func_op
}
}
/// Generic function to create an MLIR operation.
///
/// Create an MLIR operation based on its mnemonic (e.g. addi), its operands, result types, and
/// attributes. Result types can be inferred automatically if the operation itself supports that.
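/// # Examples
/// A minimal sketch using helpers from this module:
/// ```
/// use concrete_compiler::mlir::*;
/// use concrete_compiler::mlir::ffi::*;
/// unsafe {
/// let context = mlirContextCreate();
/// register_all_dialects(context);
/// // two i64 constants used as operands
/// let lhs_op = create_constant_int_op(context, 1, 64);
/// let rhs_op = create_constant_int_op(context, 2, 64);
/// let lhs = mlirOperationGetResult(lhs_op, 0);
/// let rhs = mlirOperationGetResult(rhs_op, 0);
/// // "arith.addi" supports result type inference, so no result types are passed
/// let _addi_op = create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true);
/// }
/// ```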
pub fn create_op(
context: MlirContext,
mnemonic: &str,
operands: &[MlirValue],
results: &[MlirType],
attrs: &[MlirNamedAttribute],
auto_result_type_inference: bool,
) -> MlirOperation {
let op_mnemonic = CString::new(mnemonic).unwrap();
unsafe {
let location = mlirLocationUnknownGet(context);
let mut op_state = mlirOperationStateGet(
mlirStringRefCreateFromCString(op_mnemonic.as_ptr()),
location,
);
mlirOperationStateAddOperands(
&mut op_state,
operands.len().try_into().unwrap(),
operands.as_ptr(),
);
mlirOperationStateAddAttributes(
&mut op_state,
attrs.len().try_into().unwrap(),
attrs.as_ptr(),
);
if auto_result_type_inference {
mlirOperationStateEnableResultTypeInference(&mut op_state);
} else {
mlirOperationStateAddResults(
&mut op_state,
results.len().try_into().unwrap(),
results.as_ptr(),
);
}
mlirOperationCreate(&mut op_state)
}
}
pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation {
create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true)
}
pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperation {
create_op(context, "func.return", &[ret_value], &[], &[], false)
}
pub fn create_constant_int_op(context: MlirContext, cst_value: i64, width: u32) -> MlirOperation {
unsafe {
let result_type = mlirIntegerTypeGet(context, width);
let value_str = CString::new("value").unwrap();
let value_attr = mlirNamedAttributeGet(
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
mlirIntegerAttrGet(result_type, cst_value),
);
create_op(
context,
"arith.constant",
&[],
&[result_type],
&[value_attr],
true,
)
}
}
pub fn create_constant_flat_tensor_op(
context: MlirContext,
cst_table: &[i64],
bitwidth: u32,
) -> MlirOperation {
let shape = [cst_table.len().try_into().unwrap()];
create_constant_tensor_op(context, &shape, cst_table, bitwidth)
}
pub fn create_constant_tensor_op(
context: MlirContext,
shape: &[i64],
cst_table: &[i64],
bitwidth: u32,
) -> MlirOperation {
unsafe {
let result_type = mlirRankedTensorTypeGet(
shape.len().try_into().unwrap(),
shape.as_ptr(),
mlirIntegerTypeGet(context, bitwidth),
mlirAttributeGetNull(),
);
let cst_table_attrs: Vec<MlirAttribute> = cst_table
.into_iter()
.map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, bitwidth), *value))
.collect();
let value_str = CString::new("value").unwrap();
let value_attr = mlirNamedAttributeGet(
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
mlirDenseElementsAttrGet(
result_type,
cst_table.len().try_into().unwrap(),
cst_table_attrs.as_ptr(),
),
);
create_op(
context,
"arith.constant",
&[],
&[result_type],
&[value_attr],
true,
)
}
}
pub unsafe fn register_all_dialects(context: MlirContext) {
let registry = mlirDialectRegistryCreate();
mlirRegisterAllDialects(registry);
mlirContextAppendDialectRegistry(context, registry);
mlirContextLoadAllAvailableDialects(context);
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_function_type() {
unsafe {
let context = mlirContextCreate();
let func_type = mlirFunctionTypeGet(context, 0, std::ptr::null(), 0, std::ptr::null());
assert!(mlirTypeIsAFunction(func_type));
mlirContextDestroy(context);
}
}
#[test]
fn test_module_parsing() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
let module_string = "
module{
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%1 = arith.addi %arg0, %arg1 : i64
return %1: i64
}
}";
let module_cstring = CString::new(module_string).unwrap();
let module_reference = mlirStringRefCreateFromCString(module_cstring.as_ptr());
let parsed_module = mlirModuleCreateParse(context, module_reference);
let parsed_func = mlirBlockGetFirstOperation(mlirModuleGetBody(parsed_module));
let func_type_str = CString::new("function_type").unwrap();
// just check that we do have a function here, which should be enough to know that parsing worked well
assert!(mlirTypeIsAFunction(mlirTypeAttrGetValue(
mlirOperationGetAttributeByName(
parsed_func,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
)
)));
}
}
#[test]
fn test_module_creation() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// input/output types
let func_input_types = [
mlirIntegerTypeGet(context, 64),
mlirIntegerTypeGet(context, 64),
];
let func_output_types = [mlirIntegerTypeGet(context, 64)];
let func_op = create_func_with_block(
context,
"test",
func_input_types.as_slice(),
func_output_types.as_slice(),
);
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
let func_args = [
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
];
// create addi operation and append it to the block
let addi_op = create_addi_op(context, func_args[0], func_args[1]);
mlirBlockAppendOwnedOperation(func_block, addi_op);
// create ret operation and append it to the block
let ret_op = create_ret_op(context, mlirOperationGetResult(addi_op, 0));
mlirBlockAppendOwnedOperation(func_block, ret_op);
// create module to hold the previously created function
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
let printed_module =
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let expected_module = "\
module {
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
}
";
assert_eq!(
printed_module, expected_module,
"left: \n{}, right: \n{}",
printed_module, expected_module
);
}
}
#[test]
fn test_constant_flat_tensor() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// create a constant flat tensor
let constant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
let printed_op = print_mlir_operation_to_string(constant_flat_tensor_op);
let expected_op = "%cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_tensor() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// create a constant tensor
let constant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
let printed_op = print_mlir_operation_to_string(constant_tensor_op);
let expected_op = "%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_tensor_with_single_elem() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// create a constant tensor
let constant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
let printed_op = print_mlir_operation_to_string(constant_tensor_op);
let expected_op = "%cst = arith.constant dense<0> : tensor<2x2xi7>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_int() {
unsafe {
let context = mlirContextCreate();
register_all_dialects(context);
// create a constant flat tensor
let constant_int_op = create_constant_int_op(context, 73, 10);
let printed_op = print_mlir_operation_to_string(constant_int_op);
let expected_op = "%c73_i10 = arith.constant 73 : i10\n";
assert_eq!(printed_op, expected_op);
}
}
}

View File

@@ -1,2 +0,0 @@
add_subdirectory(Dialect)
add_subdirectory(Support)

View File

@@ -1,3 +0,0 @@
add_subdirectory(FHE)
add_subdirectory(FHELinalg)
add_subdirectory(Tracing)

View File

@@ -1,11 +0,0 @@
set(LLVM_OPTIONAL_SOURCES FHE.cpp)
add_mlir_public_c_api_library(
CONCRETELANGCAPIFHE
FHE.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRCAPIIR
FHEDialect)

View File

@@ -1,76 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Dialect/FHE.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/StorageUniquerSupport.h"
using namespace mlir::concretelang::FHE;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
//===----------------------------------------------------------------------===//
// Type API.
//===----------------------------------------------------------------------===//
template <typename T>
MlirTypeOrError IntegerTypeGetChecked(MlirContext ctx, unsigned width) {
MlirTypeOrError type = {{NULL}, false};
auto catchError = [&]() -> mlir::InFlightDiagnostic {
type.isError = true;
mlir::DiagnosticEngine &engine = unwrap(ctx)->getDiagEngine();
// The goal here is to make getChecked work, but we don't want the CAPI
// to stop execution on error, so we leave the error handling logic to
// the user of the CAPI
return engine.emit(mlir::UnknownLoc::get(unwrap(ctx)),
mlir::DiagnosticSeverity::Warning);
};
T integerType = T::getChecked(catchError, unwrap(ctx), width);
if (type.isError) {
return type;
}
type.type = wrap(integerType);
return type;
}
bool fheTypeIsAnEncryptedIntegerType(MlirType type) {
return unwrap(type).isa<EncryptedUnsignedIntegerType>();
}
MlirTypeOrError fheEncryptedIntegerTypeGetChecked(MlirContext ctx,
unsigned width) {
return IntegerTypeGetChecked<EncryptedUnsignedIntegerType>(ctx, width);
}
bool fheTypeIsAnEncryptedSignedIntegerType(MlirType type) {
return unwrap(type).isa<EncryptedSignedIntegerType>();
}
MlirTypeOrError fheEncryptedSignedIntegerTypeGetChecked(MlirContext ctx,
unsigned width) {
return IntegerTypeGetChecked<EncryptedSignedIntegerType>(ctx, width);
}
unsigned fheTypeIntegerWidthGet(MlirType integerType) {
mlir::Type type = unwrap(integerType);
auto eint = type.dyn_cast_or_null<EncryptedUnsignedIntegerType>();
if (eint) {
return eint.getWidth();
}
auto esint = type.dyn_cast_or_null<EncryptedSignedIntegerType>();
if (esint) {
return esint.getWidth();
}
return 0;
}

View File

@@ -1,11 +0,0 @@
set(LLVM_OPTIONAL_SOURCES FHELinalg.cpp)
add_mlir_public_c_api_library(
CONCRETELANGCAPIFHELINALG
FHELinalg.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRCAPIIR
FHELinalgDialect)

View File

@@ -1,20 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Dialect/FHELinalg.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
using namespace mlir::concretelang::FHELinalg;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg, FHELinalgDialect)

View File

@@ -1,11 +0,0 @@
set(LLVM_OPTIONAL_SOURCES Tracing.cpp)
add_mlir_public_c_api_library(
CONCRETELANGCAPITRACING
Tracing.cpp
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRCAPIIR
TracingDialect)

View File

@@ -1,19 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Dialect/Tracing.h"
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Registration.h"
#include "mlir/CAPI/Support.h"
using namespace mlir::concretelang::Tracing;
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Tracing, tracing, TracingDialect)

View File

@@ -1,4 +0,0 @@
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
add_mlir_public_c_api_library(CONCRETELANGCAPISupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR
ConcretelangSupport)

View File

@@ -1,799 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/CAPI/Wrappers.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/LambdaArgument.h"
#include "concretelang/Support/LambdaSupport.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
#include <numeric>
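/// Releases the C++ object wrapped inside a C struct, as well as its error buffer, if any.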
#define C_STRUCT_CLEANER(c_struct) \
auto *cpp = unwrap(c_struct); \
if (cpp != NULL) \
delete cpp; \
const char *error = getErrorPtr(c_struct); \
if (error != NULL) \
delete[] error;
/// ********** BufferRef CAPI **************************************************
BufferRef bufferRefCreate(const char *buffer, size_t length) {
return BufferRef{buffer, length, NULL};
}
BufferRef bufferRefFromString(std::string str) {
char *buffer = new char[str.size()];
memcpy(buffer, str.c_str(), str.size());
return bufferRefCreate(buffer, str.size());
}
BufferRef bufferRefFromStringError(std::string error) {
char *buffer = new char[error.size()];
memcpy(buffer, error.c_str(), error.size());
return BufferRef{NULL, 0, buffer};
}
void bufferRefDestroy(BufferRef buffer) {
if (buffer.data != NULL)
delete[] buffer.data;
if (buffer.error != NULL)
delete[] buffer.error;
}
/// ********** Utilities *******************************************************
void mlirStringRefDestroy(MlirStringRef str) { delete[] str.data; }
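/// Serializes a wrapped object exposing `serialize(std::ostream &)` into a BufferRef,
/// or returns an error BufferRef carrying the failure message.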
template <typename T> BufferRef serialize(T toSerialize) {
std::ostringstream ostream(std::ios::binary);
auto voidOrError = unwrap(toSerialize)->serialize(ostream);
if (voidOrError.has_error()) {
return bufferRefFromStringError(voidOrError.error().mesg);
}
return bufferRefFromString(ostream.str());
}
/// ********** CompilationOptions CAPI *****************************************
CompilationOptions
compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize,
bool batchTFHEOps, bool dataflowParallelize,
bool emitGPUOps, bool loopParallelize,
bool optimizeTFHE, OptimizerConfig optimizerConfig,
bool verifyDiagnostics) {
std::string funcNameStr(funcName.data, funcName.length);
auto options = new mlir::concretelang::CompilationOptions(funcNameStr);
options->autoParallelize = autoParallelize;
options->batchTFHEOps = batchTFHEOps;
options->dataflowParallelize = dataflowParallelize;
options->emitGPUOps = emitGPUOps;
options->loopParallelize = loopParallelize;
options->optimizeTFHE = optimizeTFHE;
options->optimizerConfig = *unwrap(optimizerConfig);
options->verifyDiagnostics = verifyDiagnostics;
return wrap(options);
}
CompilationOptions compilationOptionsCreateDefault() {
return wrap(new mlir::concretelang::CompilationOptions("main"));
}
void compilationOptionsDestroy(CompilationOptions options){
C_STRUCT_CLEANER(options)}
/// ********** OptimizerConfig CAPI ********************************************
OptimizerConfig
optimizerConfigCreate(bool display, double fallback_log_norm_woppbs,
double global_p_error, double p_error,
uint64_t security,
mlir::concretelang::optimizer::Strategy strategy,
bool use_gpu_constraints,
uint32_t ciphertext_modulus_log,
uint32_t fft_precision) {
auto config = new mlir::concretelang::optimizer::Config();
config->display = display;
config->fallback_log_norm_woppbs = fallback_log_norm_woppbs;
config->global_p_error = global_p_error;
config->p_error = p_error;
config->security = security;
config->strategy = strategy;
config->use_gpu_constraints = use_gpu_constraints;
config->ciphertext_modulus_log = ciphertext_modulus_log;
config->fft_precision = fft_precision;
return wrap(config);
}
OptimizerConfig optimizerConfigCreateDefault() {
return wrap(new mlir::concretelang::optimizer::Config());
}
void optimizerConfigDestroy(OptimizerConfig config){C_STRUCT_CLEANER(config)}
/// ********** CompilerEngine CAPI *********************************************
CompilerEngine compilerEngineCreate() {
auto *engine = new mlir::concretelang::CompilerEngine(
mlir::concretelang::CompilationContext::createShared());
return wrap(engine);
}
void compilerEngineDestroy(CompilerEngine engine){C_STRUCT_CLEANER(engine)}
/// Map C compilationTarget to Cpp
llvm::Expected<mlir::concretelang::CompilerEngine::Target>
targetConvertToCppFromC(CompilationTarget target) {
switch (target) {
case ROUND_TRIP:
return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP;
case FHE:
return mlir::concretelang::CompilerEngine::Target::FHE;
case TFHE:
return mlir::concretelang::CompilerEngine::Target::TFHE;
case PARAMETRIZED_TFHE:
return mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
case NORMALIZED_TFHE:
return mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
case BATCHED_TFHE:
return mlir::concretelang::CompilerEngine::Target::BATCHED_TFHE;
case CONCRETE:
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
case STD:
return mlir::concretelang::CompilerEngine::Target::STD;
case LLVM:
return mlir::concretelang::CompilerEngine::Target::LLVM;
case LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::LLVM_IR;
case OPTIMIZED_LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
case LIBRARY:
return mlir::concretelang::CompilerEngine::Target::LIBRARY;
}
return mlir::concretelang::StreamStringError("invalid compilation target");
}
CompilationResult compilerEngineCompile(CompilerEngine engine,
MlirStringRef module,
CompilationTarget target) {
std::string module_str(module.data, module.length);
auto targetCppOrError = targetConvertToCppFromC(target);
if (!targetCppOrError) { // invalid target
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
llvm::toString(targetCppOrError.takeError()));
}
auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get());
if (!retOrError) { // compilation error
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
llvm::toString(retOrError.takeError()));
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
}
void compilerEngineCompileSetOptions(CompilerEngine engine,
CompilationOptions options) {
unwrap(engine)->setCompilationOptions(*unwrap(options));
}
/// ********** CompilationResult CAPI ******************************************
MlirStringRef compilationResultGetModuleString(CompilationResult result) {
// print the module into a string
std::string moduleString;
llvm::raw_string_ostream os(moduleString);
unwrap(result)->mlirModuleRef->get().print(os);
// allocate buffer and copy module string
char *buffer = new char[moduleString.length() + 1];
strcpy(buffer, moduleString.c_str());
return mlirStringRefCreate(buffer, moduleString.length());
}
void compilationResultDestroyModuleString(MlirStringRef str) {
mlirStringRefDestroy(str);
}
void compilationResultDestroy(CompilationResult result){
C_STRUCT_CLEANER(result)}
/// ********** Library CAPI ****************************************************
Library libraryCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool cleanUp) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::CompilerEngine::Library(
outputDirPathStr, runtimeLibraryPathStr, cleanUp));
}
void libraryDestroy(Library lib) { C_STRUCT_CLEANER(lib) }
/// ********** LibraryCompilationResult CAPI ***********************************
void libraryCompilationResultDestroy(LibraryCompilationResult result){
C_STRUCT_CLEANER(result)}
/// ********** LibrarySupport CAPI *********************************************
LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath,
bool generateSharedLib, bool generateStaticLib,
bool generateClientParameters,
bool generateCompilationFeedback,
bool generateCppHeader) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::LibrarySupport(
outputDirPathStr, runtimeLibraryPathStr, generateSharedLib,
generateStaticLib, generateClientParameters, generateCompilationFeedback,
generateCppHeader));
}
LibraryCompilationResult librarySupportCompile(LibrarySupport support,
MlirStringRef module,
CompilationOptions options) {
std::string moduleStr(module.data, module.length);
auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options));
if (!retOrError) {
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL,
llvm::toString(retOrError.takeError()));
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().release()));
}
ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
LibraryCompilationResult result) {
auto serverLambdaOrError = unwrap(support)->loadServerLambda(*unwrap(result));
if (!serverLambdaOrError) {
return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL,
llvm::toString(serverLambdaOrError.takeError()));
}
return wrap(new mlir::concretelang::serverlib::ServerLambda(
serverLambdaOrError.get()));
}
ClientParameters
librarySupportLoadClientParameters(LibrarySupport support,
LibraryCompilationResult result) {
auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result));
if (!paramsOrError) {
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL,
llvm::toString(paramsOrError.takeError()));
}
return wrap(
new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get()));
}
LibraryCompilationResult
librarySupportLoadCompilationResult(LibrarySupport support) {
auto retOrError = unwrap(support)->loadCompilationResult();
if (!retOrError) {
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL,
llvm::toString(retOrError.takeError()));
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().release()));
}
CompilationFeedback
librarySupportLoadCompilationFeedback(LibrarySupport support,
LibraryCompilationResult result) {
auto feedbackOrError =
unwrap(support)->loadCompilationFeedback(*unwrap(result));
if (!feedbackOrError) {
return wrap((mlir::concretelang::CompilationFeedback *)NULL,
llvm::toString(feedbackOrError.takeError()));
}
return wrap(
new mlir::concretelang::CompilationFeedback(feedbackOrError.get()));
}
PublicResult librarySupportServerCall(LibrarySupport support,
ServerLambda server_lambda,
PublicArguments args,
EvaluationKeys evalKeys) {
auto resultOrError = unwrap(support)->serverCall(
*unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys));
if (!resultOrError) {
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL,
llvm::toString(resultOrError.takeError()));
}
return wrap(resultOrError.get().release());
}
MlirStringRef librarySupportGetSharedLibPath(LibrarySupport support) {
auto path = unwrap(support)->getSharedLibPath();
// allocate buffer and copy the shared library path string
char *buffer = new char[path.length() + 1];
strcpy(buffer, path.c_str());
return mlirStringRefCreate(buffer, path.length());
}
MlirStringRef librarySupportGetClientParametersPath(LibrarySupport support) {
auto path = unwrap(support)->getClientParametersPath();
// allocate buffer and copy the client parameters path string
char *buffer = new char[path.length() + 1];
strcpy(buffer, path.c_str());
return mlirStringRefCreate(buffer, path.length());
}
void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) }
/// ********** ServerLambda CAPI ***********************************************
void serverLambdaDestroy(ServerLambda server){C_STRUCT_CLEANER(server)}
/// ********** ClientParameters CAPI *******************************************
BufferRef clientParametersSerialize(ClientParameters params) {
llvm::json::Value value(*unwrap(params));
std::string jsonParams;
llvm::raw_string_ostream ostream(jsonParams);
ostream << value;
char *buffer = new char[jsonParams.size() + 1];
strcpy(buffer, jsonParams.c_str());
return bufferRefCreate(buffer, jsonParams.size());
}
ClientParameters clientParametersUnserialize(BufferRef buffer) {
std::string json(buffer.data, buffer.length);
auto paramsOrError =
llvm::json::parse<mlir::concretelang::ClientParameters>(json);
if (!paramsOrError) {
return wrap((mlir::concretelang::ClientParameters *)NULL,
llvm::toString(paramsOrError.takeError()));
}
return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get()));
}
ClientParameters clientParametersCopy(ClientParameters params) {
return wrap(new mlir::concretelang::ClientParameters(*unwrap(params)));
}
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
size_t clientParametersOutputsSize(ClientParameters params) {
return unwrap(params)->outputs.size();
}
size_t clientParametersInputsSize(ClientParameters params) {
return unwrap(params)->inputs.size();
}
CircuitGate clientParametersOutputCircuitGate(ClientParameters params,
size_t index) {
auto &cppGate = unwrap(params)->outputs[index];
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
return wrap(cppGateCopy);
}
CircuitGate clientParametersInputCircuitGate(ClientParameters params,
size_t index) {
auto &cppGate = unwrap(params)->inputs[index];
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
return wrap(cppGateCopy);
}
EncryptionGate circuitGateEncryptionGate(CircuitGate circuit_gate) {
auto &maybe_gate = unwrap(circuit_gate)->encryption;
if (maybe_gate) {
auto *copy = new mlir::concretelang::clientlib::EncryptionGate(*maybe_gate);
return wrap(copy);
}
return (static_cast<EncryptionGate (*)(
mlir::concretelang::clientlib::EncryptionGate *)>(wrap))(nullptr);
}
double encryptionGateVariance(EncryptionGate encryption_gate) {
return unwrap(encryption_gate)->variance;
}
Encoding encryptionGateEncoding(EncryptionGate encryption_gate) {
auto &cppEncoding = unwrap(encryption_gate)->encoding;
auto *copy = new mlir::concretelang::clientlib::Encoding(cppEncoding);
return wrap(copy);
}
uint64_t encodingPrecision(Encoding encoding) {
return unwrap(encoding)->precision;
}
void circuitGateDestroy(CircuitGate gate) { C_STRUCT_CLEANER(gate) }
void encryptionGateDestroy(EncryptionGate gate) { C_STRUCT_CLEANER(gate) }
void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
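// Illustrative sketch (hypothetical caller code, not part of the API above):
// a possible round trip through the JSON serialization helpers, assuming
// `params` is a valid ClientParameters handle obtained from a compilation.
// Cleanup of the returned BufferRef is omitted, since no buffer destructor is
// shown in this excerpt.
//
//   BufferRef json = clientParametersSerialize(params);
//   ClientParameters restored = clientParametersUnserialize(json);
//   assert(clientParametersInputsSize(restored) ==
//          clientParametersInputsSize(params));
//   clientParametersDestroy(restored);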
/// ********** KeySet CAPI *****************************************************
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
uint64_t seed_lsb) {
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = concretelang::clientlib::ConcreteCSPRNG(seed);
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
*unwrap(params), std::move(csprng));
if (keySet.has_error()) {
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
keySet.error().mesg);
}
return wrap(keySet.value().release());
}
EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) {
return wrap(new mlir::concretelang::clientlib::EvaluationKeys(
unwrap(keySet)->evaluationKeys()));
}
void keySetDestroy(KeySet keySet){C_STRUCT_CLEANER(keySet)}
/// ********** KeySetCache CAPI ************************************************
KeySetCache keySetCacheCreate(MlirStringRef cachePath) {
std::string cachePathStr(cachePath.data, cachePath.length);
return wrap(new mlir::concretelang::clientlib::KeySetCache(cachePathStr));
}
KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb) {
auto keySetOrError =
unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb);
if (keySetOrError.has_error()) {
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
keySetOrError.error().mesg);
}
return wrap(keySetOrError.value().release());
}
void keySetCacheDestroy(KeySetCache keySetCache){C_STRUCT_CLEANER(keySetCache)}
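// Illustrative sketch (hypothetical caller code): generating or loading a key
// set through the cache, then extracting the evaluation keys for the server.
// `params` is assumed to be a valid ClientParameters handle; the cache
// directory is arbitrary.
//
//   const char *dir = "/tmp/concrete-keycache";
//   KeySetCache cache = keySetCacheCreate(mlirStringRefCreate(dir, strlen(dir)));
//   KeySet keys = keySetCacheLoadOrGenerateKeySet(cache, params, 0, 0);
//   EvaluationKeys evalKeys = keySetGetEvaluationKeys(keys);
//   // ... encrypt arguments and run the circuit ...
//   keySetDestroy(keys);
//   keySetCacheDestroy(cache);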
/// ********** EvaluationKeys CAPI *********************************************
BufferRef evaluationKeysSerialize(EvaluationKeys keys) {
std::ostringstream ostream(std::ios::binary);
concretelang::clientlib::operator<<(ostream, *unwrap(keys));
if (ostream.fail()) {
return bufferRefFromStringError(
"output stream failure during evaluation keys serialization");
}
return bufferRefFromString(ostream.str());
}
EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) {
std::stringstream istream(std::string(buffer.data, buffer.length));
concretelang::clientlib::EvaluationKeys evaluationKeys =
concretelang::clientlib::readEvaluationKeys(istream);
if (istream.fail()) {
return wrap((concretelang::clientlib::EvaluationKeys *)NULL,
"input stream failure during evaluation keys unserialization");
}
return wrap(new concretelang::clientlib::EvaluationKeys(evaluationKeys));
}
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
C_STRUCT_CLEANER(evaluationKeys);
}
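// Illustrative sketch (hypothetical caller code): shipping evaluation keys to
// a server as a plain byte buffer and restoring them on the other side.
//
//   BufferRef wire = evaluationKeysSerialize(evalKeys);
//   // ... transport wire.data / wire.length over the network ...
//   EvaluationKeys serverKeys = evaluationKeysUnserialize(wire);
//   // ... use serverKeys in librarySupportServerCall ...
//   evaluationKeysDestroy(serverKeys);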
/// ********** LambdaArgument CAPI *********************************************
LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
}
int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) {
if (rank == 0) // not a tensor
return 1;
auto size = dims[0];
for (size_t i = 1; i < rank; i++)
size *= dims[i];
return size;
}
LambdaArgument lambdaArgumentFromTensorU8(const uint8_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint8_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU16(const uint16_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint16_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU32(const uint32_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint32_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>(data_vector,
dims_vector));
}
LambdaArgument lambdaArgumentFromTensorU64(const uint64_t *data,
const int64_t *dims, size_t rank) {
std::vector<uint64_t> data_vector(data,
data + getSizeFromRankAndDims(rank, dims));
std::vector<int64_t> dims_vector(dims, dims + rank);
return wrap(new mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>(data_vector,
dims_vector));
}
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
return unwrap(lambdaArg)
->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>();
}
uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg) {
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
unwrap(lambdaArg)
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
assert(arg != nullptr && "lambda argument isn't a scalar");
return arg->getValue();
}
bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
return unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
}
template <typename T>
bool copyTensorDataToBuffer(
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> *tensor,
uint64_t *buffer) {
auto *data = tensor->getValue();
auto sizeOrError = tensor->getNumElements();
if (!sizeOrError) {
llvm::errs() << llvm::toString(sizeOrError.takeError());
return false;
}
auto size = sizeOrError.get();
for (size_t i = 0; i < size; i++)
buffer[i] = data[i];
return true;
}
bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) {
auto arg = unwrap(lambdaArg);
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return copyTensorDataToBuffer(tensor, buffer);
}
return false;
}
size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) {
auto arg = unwrap(lambdaArg);
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return tensor->getDimensions().size();
}
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return tensor->getDimensions().size();
}
return 0;
}
int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) {
auto arg = unwrap(lambdaArg);
std::vector<int64_t> dims;
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
dims = tensor->getDimensions();
} else {
return 0;
}
return std::accumulate(std::begin(dims), std::end(dims), 1,
std::multiplies<int64_t>());
}
bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) {
auto arg = unwrap(lambdaArg);
std::vector<int64_t> dims;
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
dims = tensor->getDimensions();
} else if (auto tensor =
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
dims = tensor->getDimensions();
} else {
return false;
}
memcpy(buffer, dims.data(), sizeof(int64_t) * dims.size());
return true;
}
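// Illustrative sketch (hypothetical caller code): reading a tensor argument
// back through the accessors above. `arg` is assumed to be a LambdaArgument
// holding a tensor, e.g. one returned by publicResultDecrypt.
//
//   size_t rank = lambdaArgumentGetTensorRank(arg);
//   std::vector<int64_t> dims(rank);
//   lambdaArgumentGetTensorDims(arg, dims.data());
//   std::vector<uint64_t> data(lambdaArgumentGetTensorDataSize(arg));
//   lambdaArgumentGetTensorData(arg, data.data());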
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
size_t argNumber, ClientParameters params,
KeySet keySet) {
std::vector<const mlir::concretelang::LambdaArgument *> args;
for (size_t i = 0; i < argNumber; i++)
args.push_back(unwrap(lambdaArgs[i]));
auto publicArgsOrError =
mlir::concretelang::LambdaSupport<int, int>::exportArguments(
*unwrap(params), *unwrap(keySet), args);
if (!publicArgsOrError) {
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL,
llvm::toString(publicArgsOrError.takeError()));
}
return wrap(publicArgsOrError.get().release());
}
void lambdaArgumentDestroy(LambdaArgument lambdaArg){
C_STRUCT_CLEANER(lambdaArg)}
/// ********** PublicArguments CAPI ********************************************
BufferRef publicArgumentsSerialize(PublicArguments args) {
return serialize(args);
}
PublicArguments publicArgumentsUnserialize(BufferRef buffer,
ClientParameters params) {
std::stringstream istream(std::string(buffer.data, buffer.length));
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
*unwrap(params), istream);
if (!argsOrError) {
return wrap((concretelang::clientlib::PublicArguments *)NULL,
argsOrError.error().mesg);
}
return wrap(argsOrError.value().release());
}
void publicArgumentsDestroy(PublicArguments publicArgs){
C_STRUCT_CLEANER(publicArgs)}
/// ********** PublicResult CAPI ***********************************************
LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
lambdaArgOrError = mlir::concretelang::typedResult<
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
*unwrap(keySet), *unwrap(publicResult));
if (!lambdaArgOrError) {
return wrap((mlir::concretelang::LambdaArgument *)NULL,
llvm::toString(lambdaArgOrError.takeError()));
}
return wrap(lambdaArgOrError.get().release());
}
BufferRef publicResultSerialize(PublicResult result) {
return serialize(result);
}
PublicResult publicResultUnserialize(BufferRef buffer,
ClientParameters params) {
std::stringstream istream(std::string(buffer.data, buffer.length));
auto resultOrError = concretelang::clientlib::PublicResult::unserialize(
*unwrap(params), istream);
if (!resultOrError) {
return wrap((concretelang::clientlib::PublicResult *)NULL,
resultOrError.error().mesg);
}
return wrap(resultOrError.value().release());
}
void publicResultDestroy(PublicResult publicResult) {
C_STRUCT_CLEANER(publicResult)
}
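// Illustrative sketch (hypothetical caller code): a complete encrypt, execute,
// decrypt round trip with the helpers above. `support`, `lambda`, `params`,
// `keys` and `evalKeys` are assumed to be valid handles created earlier.
//
//   LambdaArgument in = lambdaArgumentFromScalar(3);
//   PublicArguments encArgs = lambdaArgumentEncrypt(&in, 1, params, keys);
//   PublicResult encRes =
//       librarySupportServerCall(support, lambda, encArgs, evalKeys);
//   LambdaArgument out = publicResultDecrypt(encRes, keys);
//   uint64_t clear = lambdaArgumentGetScalar(out);
//   lambdaArgumentDestroy(in);
//   lambdaArgumentDestroy(out);
//   publicArgumentsDestroy(encArgs);
//   publicResultDestroy(encRes);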
/// ********** CompilationFeedback CAPI ****************************************
double compilationFeedbackGetComplexity(CompilationFeedback feedback) {
return unwrap(feedback)->complexity;
}
double compilationFeedbackGetPError(CompilationFeedback feedback) {
return unwrap(feedback)->pError;
}
double compilationFeedbackGetGlobalPError(CompilationFeedback feedback) {
return unwrap(feedback)->globalPError;
}
uint64_t
compilationFeedbackGetTotalSecretKeysSize(CompilationFeedback feedback) {
return unwrap(feedback)->totalSecretKeysSize;
}
uint64_t
compilationFeedbackGetTotalBootstrapKeysSize(CompilationFeedback feedback) {
return unwrap(feedback)->totalBootstrapKeysSize;
}
uint64_t
compilationFeedbackGetTotalKeyswitchKeysSize(CompilationFeedback feedback) {
return unwrap(feedback)->totalKeyswitchKeysSize;
}
uint64_t compilationFeedbackGetTotalInputsSize(CompilationFeedback feedback) {
return unwrap(feedback)->totalInputsSize;
}
uint64_t compilationFeedbackGetTotalOutputsSize(CompilationFeedback feedback) {
return unwrap(feedback)->totalOutputsSize;
}
void compilationFeedbackDestroy(CompilationFeedback feedback) {
C_STRUCT_CLEANER(feedback)
}

View File

@@ -1,5 +1,6 @@
add_subdirectory(Analysis)
add_subdirectory(Dialect)
add_subdirectory(Common)
add_subdirectory(Conversion)
add_subdirectory(Transforms)
add_subdirectory(Support)
@@ -8,4 +9,3 @@ add_subdirectory(ClientLib)
add_subdirectory(Bindings)
add_subdirectory(ServerLib)
add_subdirectory(Interfaces)
add_subdirectory(CAPI)

View File

@@ -1,18 +1,15 @@
add_compile_options(-fexceptions)
add_mlir_library(
ConcretelangClientLib
ClientLambda.cpp
ClientParameters.cpp
EvaluationKeys.cpp
CRT.cpp
EncryptedArguments.cpp
KeySet.cpp
KeySetCache.cpp
PublicArguments.cpp
Serializers.cpp
ClientLib.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
${PROJECT_SOURCE_DIR}/include/concretelang/Common
LINK_LIBS
ConcretelangCommon
PUBLIC
concrete_cpu)
concrete_cpu
concrete-protocol)
target_include_directories(ConcretelangClientLib PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})

View File

@@ -1,185 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <dlfcn.h>
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/Serializers.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
outcome::checked<ClientLambda, StringError>
ClientLambda::load(std::string functionName, std::string jsonPath) {
OUTCOME_TRY(auto all_params, ClientParameters::load(jsonPath));
auto param = llvm::find_if(all_params, [&](ClientParameters param) {
return param.functionName == functionName;
});
if (param == all_params.end()) {
return StringError("ClientLambda: cannot find function ")
           << functionName << " in client parameters " << jsonPath;
}
if (param->outputs.size() != 1) {
return StringError("ClientLambda: output arity (")
           << std::to_string(param->outputs.size()) << ") != 1 is not supported";
}
if (!param->outputs[0].encryption.has_value()) {
return StringError("ClientLambda: clear output is not yet supported");
}
ClientLambda lambda;
lambda.clientParameters = *param;
return lambda;
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
uint64_t seed_msb, uint64_t seed_lsb) {
return KeySetCache::generate(optionalCache, clientParameters, seed_msb,
seed_lsb);
}
outcome::checked<decrypted_scalar_t, StringError>
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
return v[0];
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
return result.asClearTextVector<decrypted_scalar_t>(keySet, 0);
}
outcome::checked<void, StringError> errorResultRank(size_t expected,
size_t actual) {
return StringError("Expected result has rank ")
<< expected << " and cannot be converted to rank " << actual;
}
StringError errorIncoherentSizes(size_t flatSize, size_t structuredSize) {
return StringError("Received ")
         << flatSize << " values but the sizes indicate a global size of "
<< structuredSize;
}
template <typename DecryptedTensor>
DecryptedTensor flatToTensor(decrypted_tensor_1_t &values, size_t *sizes);
template <>
decrypted_tensor_1_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
return values;
}
template <>
decrypted_tensor_2_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
decrypted_tensor_2_t result(sizes[0]);
size_t position = 0;
for (auto &dest0 : result) {
dest0.resize(sizes[1]);
for (auto &dest1 : dest0) {
dest1 = values[position++];
}
}
return result;
}
template <>
decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
decrypted_tensor_3_t result(sizes[0]);
size_t position = 0;
for (auto &dest0 : result) {
dest0.resize(sizes[1]);
for (auto &dest1 : dest0) {
dest1.resize(sizes[2]);
for (auto &dest2 : dest1) {
dest2 = values[position++];
}
}
}
return result;
}
template <typename DecryptedTensor>
outcome::checked<DecryptedTensor, StringError>
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
ClientParameters &params, size_t expectedRank,
KeySet &keySet) {
auto shape = params.outputs[0].shape;
size_t rank = shape.dimensions.size();
if (rank != expectedRank) {
return StringError("Function returns a tensor of rank ")
           << rank << " which cannot be decrypted to rank " << expectedRank;
}
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
llvm::SmallVector<size_t, 6> sizes;
for (size_t dim = 0; dim < rank; dim++) {
sizes.push_back(shape.dimensions[dim]);
}
return flatToTensor<DecryptedTensor>(values, sizes.data());
}
outcome::checked<decrypted_tensor_1_t, StringError>
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_1_t>(
result, *this, this->clientParameters, 1, keySet);
}
outcome::checked<decrypted_tensor_2_t, StringError>
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_2_t>(
result, *this, this->clientParameters, 2, keySet);
}
outcome::checked<decrypted_tensor_3_t, StringError>
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_3_t>(
result, *this, this->clientParameters, 3, keySet);
}
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
PublicResult &result) {
// compile time error if used
using COMPATIBLE_RESULT_TYPE = void;
return (Result)(COMPATIBLE_RESULT_TYPE)0;
}
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
PublicResult &result) {
return lambda.decryptReturnedScalar(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
PublicResult &result) {
return lambda.decryptReturnedTensor1(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
PublicResult &result) {
return lambda.decryptReturnedTensor2(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_3_t, StringError>
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
KeySet &keySet,
PublicResult &result) {
return lambda.decryptReturnedTensor3(keySet, result);
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -0,0 +1,132 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <optional>
#include <string>
#include <variant>
#include "boost/outcome.h"
#include "concrete-cpu.h"
#include "concrete-protocol.capnp.h"
#include "concretelang/ClientLib/ClientLib.h"
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Common/Keysets.h"
#include "concretelang/Common/Protocol.h"
#include "concretelang/Common/Transformers.h"
#include "concretelang/Common/Values.h"
using concretelang::error::Result;
using concretelang::keysets::ClientKeyset;
using concretelang::transformers::InputTransformer;
using concretelang::transformers::OutputTransformer;
using concretelang::transformers::TransformerFactory;
using concretelang::values::TransportValue;
using concretelang::values::Value;
namespace concretelang {
namespace clientlib {
Result<ClientCircuit>
ClientCircuit::create(const Message<concreteprotocol::CircuitInfo> &info,
const ClientKeyset &keyset,
std::shared_ptr<CSPRNG> csprng, bool useSimulation) {
auto inputTransformers = std::vector<InputTransformer>();
for (auto gateInfo : info.asReader().getInputs()) {
InputTransformer transformer;
if (gateInfo.getTypeInfo().hasIndex()) {
OUTCOME_TRY(transformer,
TransformerFactory::getIndexInputTransformer(gateInfo));
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getPlaintextInputTransformer(gateInfo));
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getLweCiphertextInputTransformer(
keyset, gateInfo, csprng, useSimulation));
} else {
return StringError("Malformed input gate info.");
}
inputTransformers.push_back(transformer);
}
auto outputTransformers = std::vector<OutputTransformer>();
for (auto gateInfo : info.asReader().getOutputs()) {
OutputTransformer transformer;
if (gateInfo.getTypeInfo().hasIndex()) {
OUTCOME_TRY(transformer,
TransformerFactory::getIndexOutputTransformer(gateInfo));
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getPlaintextOutputTransformer(gateInfo));
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getLweCiphertextOutputTransformer(
keyset, gateInfo, useSimulation));
} else {
return StringError("Malformed output gate info.");
}
outputTransformers.push_back(transformer);
}
return ClientCircuit(info, inputTransformers, outputTransformers);
}
Result<TransportValue> ClientCircuit::prepareInput(Value arg, size_t pos) {
if (pos >= inputTransformers.size()) {
    return StringError("Tried to prepare a Value for an invalid position.");
}
return inputTransformers[pos](arg);
}
Result<Value> ClientCircuit::processOutput(TransportValue result, size_t pos) {
if (pos >= outputTransformers.size()) {
return StringError(
"Tried to process a TransportValue for incorrect position.");
}
return outputTransformers[pos](result);
}
std::string ClientCircuit::getName() {
return circuitInfo.asReader().getName();
}
const Message<concreteprotocol::CircuitInfo> &ClientCircuit::getCircuitInfo() {
return circuitInfo;
}
Result<ClientProgram>
ClientProgram::create(const Message<concreteprotocol::ProgramInfo> &info,
const ClientKeyset &keyset,
std::shared_ptr<CSPRNG> csprng, bool useSimulation) {
ClientProgram output;
for (auto circuitInfo : info.asReader().getCircuits()) {
OUTCOME_TRY(
ClientCircuit clientCircuit,
ClientCircuit::create(circuitInfo, keyset, csprng, useSimulation));
output.circuits.push_back(clientCircuit);
}
return output;
}
Result<ClientCircuit> ClientProgram::getClientCircuit(std::string circuitName) {
for (auto circuit : circuits) {
if (circuit.getName() == circuitName) {
return circuit;
}
}
return StringError("Tried to get unknown client circuit: `" + circuitName +
"`");
}
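// Illustrative sketch (hypothetical caller code, not part of this file): using
// the program/circuit abstraction above. `programInfo`, `keyset` and `csprng`
// are assumed to come from the compilation artifact and key generation.
//
//   auto program =
//       ClientProgram::create(programInfo, keyset, csprng, /*useSimulation=*/false);
//   auto circuit = program.value().getClientCircuit("main");
//   auto transport = circuit.value().prepareInput(Value(someClearInput), 0);
//   // ... send `transport` to the server, get `resultFromServer` back ...
//   auto clear = circuit.value().processOutput(resultFromServer, 0);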
} // namespace clientlib
} // namespace concretelang

View File

@@ -1,260 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <fstream>
#include "boost/outcome.h"
#include "llvm/ADT/Hashing.h"
#include "concretelang/ClientLib/ClientParameters.h"
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
// https://stackoverflow.com/a/38140932
static inline void hash_(std::size_t &seed) {}
template <typename T, typename... Rest>
static inline void hash_(std::size_t &seed, const T &v, Rest... rest) {
// See https://softwareengineering.stackexchange.com/a/402543
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
seed ^= llvm::hash_value(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
hash_(seed, rest...);
}
static long double_to_bits(double &v) { return *reinterpret_cast<long *>(&v); }
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, dimension); }
void BootstrapKeyParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
glweDimension, double_to_bits(variance));
}
void KeyswitchKeyParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
double_to_bits(variance));
}
void PackingKeyswitchKeyParam::hash(size_t &seed) {
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
glweDimension, polynomialSize, inputLweDimension,
double_to_bits(variance));
}
std::size_t ClientParameters::hash() {
std::size_t currentHash = 1;
for (auto secretKeyParam : secretKeys) {
secretKeyParam.hash(currentHash);
}
for (auto bootstrapKeyParam : bootstrapKeys) {
bootstrapKeyParam.hash(currentHash);
}
for (auto keyswitchParam : keyswitchKeys) {
keyswitchParam.hash(currentHash);
}
for (auto packingKeyswitchKeyParam : packingKeyswitchKeys) {
packingKeyswitchKeyParam.hash(currentHash);
}
return currentHash;
}
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
llvm::json::Object object{
{"dimension", v.dimension},
};
return object;
}
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("dimension", v.dimension);
}
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"level", v.level},
{"glweDimension", v.glweDimension},
{"baseLog", v.baseLog},
{"variance", v.variance},
{"polynomialSize", v.polynomialSize},
{"inputLweDimension", v.inputLweDimension},
};
return object;
}
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("glweDimension", v.glweDimension) &&
O.map("variance", v.variance) &&
O.map("polynomialSize", v.polynomialSize) &&
O.map("inputLweDimension", v.inputLweDimension);
}
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"level", v.level},
{"baseLog", v.baseLog},
{"variance", v.variance},
};
return object;
}
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("variance", v.variance);
}
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &v) {
llvm::json::Object object{
{"inputSecretKeyID", v.inputSecretKeyID},
{"outputSecretKeyID", v.outputSecretKeyID},
{"level", v.level},
{"baseLog", v.baseLog},
{"glweDimension", v.glweDimension},
{"polynomialSize", v.polynomialSize},
{"inputLweDimension", v.inputLweDimension},
{"variance", v.variance},
};
return object;
}
bool fromJSON(const llvm::json::Value j, PackingKeyswitchKeyParam &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
O.map("glweDimension", v.glweDimension) &&
O.map("polynomialSize", v.polynomialSize) &&
O.map("inputLweDimension", v.inputLweDimension) &&
O.map("variance", v.variance);
}
llvm::json::Value toJSON(const CircuitGateShape &v) {
llvm::json::Object object{
{"width", v.width},
{"dimensions", v.dimensions},
{"size", v.size},
{"sign", v.sign},
};
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("width", v.width) && O.map("size", v.size) &&
O.map("dimensions", v.dimensions) && O.map("sign", v.sign);
}
llvm::json::Value toJSON(const Encoding &v) {
llvm::json::Object object{
{"precision", v.precision},
{"isSigned", v.isSigned},
};
if (!v.crt.empty()) {
object.insert({"crt", v.crt});
}
return object;
}
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
if (!(O && O.map("precision", v.precision) &&
O.map("isSigned", v.isSigned))) {
return false;
}
// TODO: check this is correct for an optional field
O.map("crt", v.crt);
return true;
}
llvm::json::Value toJSON(const EncryptionGate &v) {
llvm::json::Object object{
{"secretKeyID", v.secretKeyID},
{"variance", v.variance},
{"encoding", v.encoding},
};
return object;
}
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("secretKeyID", v.secretKeyID) &&
O.map("variance", v.variance) && O.map("encoding", v.encoding);
}
llvm::json::Value toJSON(const CircuitGate &v) {
llvm::json::Object object{
{"encryption", v.encryption},
{"shape", v.shape},
};
return object;
}
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("encryption", v.encryption) && O.map("shape", v.shape);
}
llvm::json::Value toJSON(const ClientParameters &v) {
llvm::json::Object object{
{"secretKeys", v.secretKeys},
{"bootstrapKeys", v.bootstrapKeys},
{"keyswitchKeys", v.keyswitchKeys},
{"packingKeyswitchKeys", v.packingKeyswitchKeys},
{"inputs", v.inputs},
{"outputs", v.outputs},
{"functionName", v.functionName},
};
return object;
}
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
llvm::json::Path p) {
llvm::json::ObjectMapper O(j, p);
return O && O.map("secretKeys", v.secretKeys) &&
O.map("bootstrapKeys", v.bootstrapKeys) &&
O.map("keyswitchKeys", v.keyswitchKeys) &&
O.map("packingKeyswitchKeys", v.packingKeyswitchKeys) &&
O.map("inputs", v.inputs) && O.map("outputs", v.outputs) &&
O.map("functionName", v.functionName);
}
std::string ClientParameters::getClientParametersPath(std::string path) {
return path + CLIENT_PARAMETERS_EXT;
}
outcome::checked<std::vector<ClientParameters>, StringError>
ClientParameters::load(std::string jsonPath) {
std::ifstream file(jsonPath);
std::string content((std::istreambuf_iterator<char>(file)),
(std::istreambuf_iterator<char>()));
if (file.fail()) {
return StringError("Cannot read file: ") << jsonPath;
}
auto expectedClientParams =
llvm::json::parse<std::vector<ClientParameters>>(content);
if (auto err = expectedClientParams.takeError()) {
return StringError("Cannot open client parameters: ")
<< llvm::toString(std::move(err)) << "\n"
<< content << "\n";
}
return expectedClientParams.get();
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -1,63 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/PublicArguments.h"
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
auto sharedValues = std::vector<SharedScalarOrTensorData>();
sharedValues.reserve(this->values.size());
for (auto &&value : this->values) {
sharedValues.push_back(SharedScalarOrTensorData(std::move(value)));
}
return std::make_unique<PublicArguments>(clientParameters, sharedValues);
}
/// Split the input integer into `size` chunks of `chunkWidth` bits each
std::vector<uint64_t> chunkInput(uint64_t value, size_t size,
unsigned int chunkWidth) {
std::vector<uint64_t> chunks;
chunks.reserve(size);
  uint64_t mask = ((uint64_t)1 << chunkWidth) - 1; // avoid 32-bit shift overflow
for (size_t i = 0; i < size; i++) {
auto chunk = value & mask;
chunks.push_back((uint64_t)chunk);
value >>= chunkWidth;
}
return chunks;
}
outcome::checked<void, StringError> checkSizes(size_t actualSize,
size_t expectedSize) {
if (actualSize == expectedSize) {
return outcome::success();
}
return StringError("function expects ")
<< expectedSize << " arguments but has been called with " << actualSize
<< " arguments";
}
outcome::checked<void, StringError>
EncryptedArguments::checkAllArgs(KeySet &keySet) {
size_t arity = keySet.numInputs();
return checkSizes(values.size(), arity);
}
outcome::checked<void, StringError>
EncryptedArguments::checkAllArgs(ClientParameters &params) {
size_t arity = params.inputs.size();
return checkSizes(values.size(), arity);
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -1,157 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concrete-cpu.h"
#include "concretelang/ClientLib/ClientParameters.h"
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
inline void getApproval() {
std::cerr << "DANGER: You are generating an empty unsecure secret keys. "
"Enter \"y\" to continue: ";
char answer;
std::cin >> answer;
if (answer != 'y') {
std::abort();
}
}
#endif
namespace concretelang {
namespace clientlib {
ConcreteCSPRNG::ConcreteCSPRNG(__uint128_t seed)
: CSPRNG(nullptr, &CONCRETE_CSPRNG_VTABLE) {
ptr = (Csprng *)aligned_alloc(CONCRETE_CSPRNG_ALIGN, CONCRETE_CSPRNG_SIZE);
struct Uint128 u128;
if (seed == 0) {
switch (concrete_cpu_crypto_secure_random_128(&u128)) {
case 1:
break;
case -1:
llvm::errs()
<< "WARNING: The generated random seed is not crypto secure\n";
break;
default:
assert(false && "Cannot instantiate a random seed");
}
} else {
for (int i = 0; i < 16; i++) {
u128.little_endian_bytes[i] = seed >> (8 * i);
}
}
concrete_cpu_construct_concrete_csprng(ptr, u128);
}
ConcreteCSPRNG::ConcreteCSPRNG(ConcreteCSPRNG &&other)
: CSPRNG(other.ptr, &CONCRETE_CSPRNG_VTABLE) {
assert(ptr != nullptr);
other.ptr = nullptr;
}
ConcreteCSPRNG::~ConcreteCSPRNG() {
if (ptr != nullptr) {
concrete_cpu_destroy_concrete_csprng(ptr);
free(ptr);
}
}
LweSecretKey::LweSecretKey(LweSecretKeyParam &parameters, CSPRNG &csprng)
: _parameters(parameters) {
// Allocate the buffer
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(parameters.dimension);
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
// In insecure debug mode, the secret key is filled with zeros.
getApproval();
for (uint64_t &val : *_buffer) {
val = 0;
}
#else
// Initialize the lwe secret key buffer
concrete_cpu_init_secret_key_u64(_buffer->data(), parameters.dimension,
csprng.ptr, csprng.vtable);
#endif
}
void LweSecretKey::encrypt(uint64_t *ciphertext, uint64_t plaintext,
double variance, CSPRNG &csprng) const {
concrete_cpu_encrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
plaintext, parameters().dimension,
variance, csprng.ptr, csprng.vtable);
}
void LweSecretKey::decrypt(const uint64_t *ciphertext,
uint64_t &plaintext) const {
concrete_cpu_decrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
parameters().dimension, &plaintext);
}
LweKeyswitchKey::LweKeyswitchKey(KeyswitchKeyParam &parameters,
LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng)
: _parameters(parameters) {
// Allocate the buffer
auto size = concrete_cpu_keyswitch_key_size_u64(
_parameters.level, _parameters.baseLog, inputKey.dimension(),
outputKey.dimension());
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size);
// Initialize the keyswitch key buffer
concrete_cpu_init_lwe_keyswitch_key_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
inputKey.dimension(), outputKey.dimension(), _parameters.level,
_parameters.baseLog, _parameters.variance, csprng.ptr, csprng.vtable);
}
LweBootstrapKey::LweBootstrapKey(BootstrapKeyParam &parameters,
LweSecretKey &inputKey,
LweSecretKey &outputKey, CSPRNG &csprng)
: _parameters(parameters) {
// TODO
size_t polynomial_size = outputKey.dimension() / _parameters.glweDimension;
// Allocate the buffer
auto size = concrete_cpu_bootstrap_key_size_u64(
_parameters.level, _parameters.glweDimension, polynomial_size,
inputKey.dimension());
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size);
  // Initialize the bootstrap key buffer
concrete_cpu_init_lwe_bootstrap_key_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
inputKey.dimension(), polynomial_size, _parameters.glweDimension,
_parameters.level, _parameters.baseLog, _parameters.variance,
Parallelism::Rayon, csprng.ptr, csprng.vtable);
}
PackingKeyswitchKey::PackingKeyswitchKey(PackingKeyswitchKeyParam &params,
LweSecretKey &inputKey,
LweSecretKey &outputKey,
CSPRNG &csprng)
: _parameters(params) {
assert(_parameters.inputLweDimension == inputKey.dimension());
assert(_parameters.glweDimension * _parameters.polynomialSize ==
outputKey.dimension());
// Allocate the buffer
auto size = concrete_cpu_lwe_packing_keyswitch_key_size(
_parameters.glweDimension, _parameters.polynomialSize, _parameters.level,
_parameters.inputLweDimension);
_buffer = std::make_shared<std::vector<uint64_t>>();
_buffer->resize(size * (_parameters.glweDimension + 1));
  // Initialize the packing keyswitch key buffer
concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
_parameters.inputLweDimension, _parameters.polynomialSize,
_parameters.glweDimension, _parameters.level, _parameters.baseLog,
_parameters.variance, Parallelism::Rayon, csprng.ptr, csprng.vtable);
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -1,303 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/CRT.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Support/Error.h"
#include <cassert>
#include <cstddef>
#include <cstdint>
namespace concretelang {
namespace clientlib {
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) {
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
OUTCOME_TRYV(keySet->generateKeysFromParams());
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
outcome::checked<std::unique_ptr<KeySet>, StringError> KeySet::fromKeys(
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
std::vector<LweBootstrapKey> bootstrapKeys,
std::vector<LweKeyswitchKey> keyswitchKeys,
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng) {
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
keySet->secretKeys = secretKeys;
keySet->bootstrapKeys = bootstrapKeys;
keySet->keyswitchKeys = keyswitchKeys;
keySet->packingKeyswitchKeys = packingKeyswitchKeys;
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
return std::move(keySet);
}
EvaluationKeys KeySet::evaluationKeys() {
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
}
outcome::checked<KeySet::SecretKeyGateMapping, StringError>
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
SecretKeyGateMapping mapping;
for (auto gate : gates) {
if (gate.encryption.has_value()) {
assert(gate.encryption->secretKeyID < this->secretKeys.size());
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {gate, skIt};
mapping.push_back(input);
} else {
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {
gate, std::nullopt};
mapping.push_back(input);
}
}
return mapping;
}
outcome::checked<void, StringError> KeySet::setupEncryptionMaterial() {
OUTCOME_TRY(this->inputs,
mapCircuitGateLweSecretKey(_clientParameters.inputs));
OUTCOME_TRY(this->outputs,
mapCircuitGateLweSecretKey(_clientParameters.outputs));
return outcome::success();
}
outcome::checked<void, StringError> KeySet::generateKeysFromParams() {
// Generate LWE secret keys
for (auto secretKeyParam : _clientParameters.secretKeys) {
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam));
}
// Generate bootstrap keys
for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) {
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam));
}
// Generate keyswitch key
for (auto keyswitchParam : _clientParameters.keyswitchKeys) {
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam));
}
// Generate packing keyswitch key
for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) {
OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam));
}
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateSecretKey(LweSecretKeyParam param) {
// Init the lwe secret key
LweSecretKey sk(param, csprng);
// Store the lwe secret key
secretKeys.push_back(sk);
return outcome::success();
}
outcome::checked<LweSecretKey, StringError>
KeySet::findLweSecretKey(LweSecretKeyID keyID) {
assert(keyID < secretKeys.size());
auto secretKey = secretKeys[keyID];
return secretKey;
}
outcome::checked<void, StringError>
KeySet::generateBootstrapKey(BootstrapKeyParam param) {
// Finding input and output secretKeys
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
// Initialize the bootstrap key
LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng);
// Store the bootstrap key
bootstrapKeys.push_back(std::move(bootstrapKey));
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generateKeyswitchKey(KeyswitchKeyParam param) {
// Finding input and output secretKeys
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
  // Initialize the keyswitch key
LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng);
// Store the keyswitch key
keyswitchKeys.push_back(keyswitchKey);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) {
  // Finding input and output secret keys
assert(param.inputSecretKeyID < secretKeys.size());
auto inputSk = secretKeys[param.inputSecretKeyID];
assert(param.outputSecretKeyID < secretKeys.size());
auto outputSk = secretKeys[param.outputSecretKeyID];
PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng);
// Store the keyswitch key
packingKeyswitchKeys.push_back(packingKeyswitchKey);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
if (argPos >= inputs.size()) {
return StringError("allocate_lwe position of argument is too high");
}
auto inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.has_value()) {
return StringError("allocate_lwe argument #")
           << argPos << " is not encrypted";
}
auto numBlocks =
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
assert(inputSk.second.has_value());
size = inputSk.second->parameters().lweSize();
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
return outcome::success();
}
bool KeySet::isInputEncrypted(size_t argPos) {
return argPos < inputs.size() &&
std::get<0>(inputs[argPos]).encryption.has_value();
}
bool KeySet::isOutputEncrypted(size_t argPos) {
return argPos < outputs.size() &&
std::get<0>(outputs[argPos]).encryption.has_value();
}
/// Return the number of bits to represents the given value
uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); }
outcome::checked<void, StringError>
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
if (argPos >= inputs.size()) {
return StringError("encrypt_lwe position of argument is too high");
}
const auto &inputSk = inputs[argPos];
auto encryption = std::get<0>(inputSk).encryption;
if (!encryption.has_value()) {
return StringError("encrypt_lwe the positional argument is not encrypted");
}
auto encoding = encryption->encoding;
assert(inputSk.second.has_value());
auto lweSecretKey = *inputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
// CRT encoding - N blocks with crt encoding
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
// Put each decomposition into a new ciphertext
auto product = crt::productOfModuli(crt);
for (auto modulus : crt) {
auto plaintext = crt::encode(input, modulus, product);
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
return outcome::success();
}
  // Simple TFHE integers - 1 block with one padding bit
// TODO we could check if the input value is in the right range
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
return outcome::success();
}
outcome::checked<void, StringError>
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
if (argPos >= outputs.size()) {
return StringError("decrypt_lwe: position of argument is too high");
}
auto outputSk = outputs[argPos];
assert(outputSk.second.has_value());
auto lweSecretKey = *outputSk.second;
auto lweSecretKeyParam = lweSecretKey.parameters();
auto encryption = std::get<0>(outputSk).encryption;
if (!encryption.has_value()) {
return StringError("decrypt_lwe: the positional argument is not encrypted");
}
auto crt = encryption->encoding.crt;
if (!crt.empty()) {
// CRT encoded TFHE integers
// Decrypt and decode remainders
std::vector<int64_t> remainders;
for (auto modulus : crt) {
uint64_t decrypted = 0;
lweSecretKey.decrypt(ciphertext, decrypted);
auto plaintext = crt::decode(decrypted, modulus);
remainders.push_back(plaintext);
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
}
// Compute the inverse crt
output = crt::iCrt(crt, remainders);
// Further decode signed integers
if (encryption->encoding.isSigned) {
uint64_t maxPos = 1;
for (auto prime : encryption->encoding.crt) {
maxPos *= prime;
}
maxPos /= 2;
if (output >= maxPos) {
output -= maxPos * 2;
}
}
} else {
    // Native encoded TFHE integers - 1 block with one padding bit
uint64_t plaintext = 0;
lweSecretKey.decrypt(ciphertext, plaintext);
// Decode unsigned integer
uint64_t precision = encryption->encoding.precision;
output = plaintext >> (64 - precision - 2);
auto carry = output % 2;
uint64_t mod = (((uint64_t)1) << (precision + 1));
output = ((output >> 1) + carry) % mod;
// Further decode signed integers.
if (encryption->encoding.isSigned) {
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
if (output >= maxPos) { // The output is actually negative.
        // Set the preceding bits to one (sign extension)
output |= UINT64_MAX << precision;
// This makes sure when the value is cast to int64, it has the correct
// value
};
}
}
return outcome::success();
}
const std::vector<LweSecretKey> &KeySet::getSecretKeys() const {
return secretKeys;
}
const std::vector<LweBootstrapKey> &KeySet::getBootstrapKeys() const {
return bootstrapKeys;
}
const std::vector<LweKeyswitchKey> &KeySet::getKeyswitchKeys() const {
return keyswitchKeys;
}
const std::vector<PackingKeyswitchKey> &
KeySet::getPackingKeyswitchKeys() const {
return packingKeyswitchKeys;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -1,292 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "boost/outcome.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include <fstream>
#include <sstream>
#include <string>
#include <utime.h>
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
inline void getApproval() {
std::cerr << "DANGER: You are using an empty unsecure secret keys. Enter "
"\"y\" to continue: ";
char answer;
std::cin >> answer;
if (answer != 'y') {
std::abort();
}
}
#endif
namespace concretelang {
namespace clientlib {
using StringError = concretelang::error::StringError;
template <class Key>
outcome::checked<Key, StringError> loadKey(llvm::SmallString<0> &path,
Key(deser)(std::istream &istream)) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
std::ifstream in((std::string)path, std::ofstream::binary);
if (in.fail()) {
return StringError("Cannot access " + (std::string)path);
}
auto key = deser(in);
if (in.bad()) {
return StringError("Cannot load key at path(") << (std::string)path << ")";
}
return key;
}
template <class Key>
outcome::checked<void, StringError> saveKey(llvm::SmallString<0> &path,
Key key) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
std::ofstream out((std::string)path, std::ofstream::binary);
if (out.fail()) {
return StringError("Cannot access " + (std::string)path);
}
out << key;
if (out.bad()) {
return StringError("Cannot save key at path(") << (std::string)path << ")";
}
out.close();
return outcome::success();
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::loadKeys(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb, std::string folderPath) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
  // Mark the folder as recently used,
// e.g. so the CI can do some cleanup of unused keys.
utime(folderPath.c_str(), nullptr);
std::vector<LweSecretKey> secretKeys;
std::vector<LweBootstrapKey> bootstrapKeys;
std::vector<LweKeyswitchKey> keyswitchKeys;
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
// Load secret keys
for (auto p : llvm::enumerate(params.secretKeys)) {
// TODO - Check parameters?
// auto param = secretKeyParam.second;
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweSecretKey));
secretKeys.push_back(key);
}
// Load bootstrap keys
for (auto p : llvm::enumerate(params.bootstrapKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweBootstrapKey));
bootstrapKeys.push_back(key);
}
// Load keyswitch keys
for (auto p : llvm::enumerate(params.keyswitchKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readLweKeyswitchKey));
keyswitchKeys.push_back(key);
}
for (auto p : llvm::enumerate(params.packingKeyswitchKeys)) {
// TODO - Check parameters?
// auto param = p.value();
llvm::SmallString<0> path(folderPath);
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
OUTCOME_TRY(auto key, loadKey(path, readPackingKeyswitchKey));
packingKeyswitchKeys.push_back(key);
}
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = ConcreteCSPRNG(seed);
OUTCOME_TRY(auto keySet,
KeySet::fromKeys(params, secretKeys, bootstrapKeys, keyswitchKeys,
packingKeyswitchKeys, std::move(csprng)));
return std::move(keySet);
}
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
llvm::SmallString<0> &folderPath) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
llvm::SmallString<0> folderIncompletePath = folderPath;
folderIncompletePath.append(".incomplete");
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
if (err) {
return StringError("Cannot create directory \"")
<< std::string(folderIncompletePath) << "\": " << err.message();
}
// Save LWE secret keys
for (auto p : llvm::enumerate(key_set.getSecretKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save bootstrap keys
for (auto p : llvm::enumerate(key_set.getBootstrapKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save keyswitch keys
for (auto p : llvm::enumerate(key_set.getKeyswitchKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
// Save packing keyswitch keys
for (auto p : llvm::enumerate(key_set.getPackingKeyswitchKeys())) {
llvm::SmallString<0> path = folderIncompletePath;
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
OUTCOME_TRYV(saveKey(path, p.value()));
}
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
if (err) {
llvm::sys::fs::remove_directories(folderIncompletePath);
}
if (!llvm::sys::fs::exists(folderPath)) {
return StringError("Cannot save directory \"")
<< std::string(folderPath) << "\"";
}
return outcome::success();
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::loadOrGenerateSave(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
llvm::SmallString<0> folderPath =
llvm::SmallString<0>(this->backingDirectoryPath);
llvm::sys::path::append(folderPath, std::to_string(params.hash()));
llvm::sys::path::append(folderPath, std::to_string(seed_msb) + "_" +
std::to_string(seed_lsb));
// Creating a lock for concurrent generation
llvm::SmallString<0> lockPath(folderPath);
lockPath.append("lock");
int FD_lock;
llvm::sys::fs::create_directories(llvm::sys::path::parent_path(lockPath));
// Open or create the lock file
auto err = llvm::sys::fs::openFile(
lockPath, FD_lock, llvm::sys::fs::CreationDisposition::CD_OpenAlways,
llvm::sys::fs::FileAccess::FA_Write, llvm::sys::fs::OpenFlags::OF_None);
if (err) {
    // parent does not exist OR permission issue (creation or write)
return StringError("Cannot access \"")
<< std::string(lockPath) << "\": " << err.message();
}
// The lock is released when the function returns.
// => any intermediate state in the function is not visible to others.
auto unlockAtReturn = llvm::make_scope_exit([&]() {
llvm::sys::fs::closeFile(FD_lock);
llvm::sys::fs::unlockFile(FD_lock);
llvm::sys::fs::remove(lockPath);
});
llvm::sys::fs::lockFile(FD_lock);
if (llvm::sys::fs::exists(folderPath)) {
    // Once it has been generated by another process (or was already here)
auto keys = loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
if (keys.has_value()) {
return keys;
} else {
std::cerr << std::string(keys.error().mesg) << "\n";
std::cerr << "Invalid KeySetCache entry " << std::string(folderPath)
<< "\n";
llvm::sys::fs::remove_directories(folderPath);
// Then we can continue as it didn't exist
}
}
std::cerr << "KeySetCache: miss, regenerating " << std::string(folderPath)
<< "\n";
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = ConcreteCSPRNG(seed);
OUTCOME_TRY(auto key_set, KeySet::generate(params, std::move(csprng)));
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
return std::move(key_set);
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
__uint128_t seed = seed_msb;
seed <<= 64;
seed += seed_lsb;
auto csprng = ConcreteCSPRNG(seed);
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
: KeySet::generate(params, std::move(csprng));
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
getApproval();
#endif
return loadOrGenerateSave(params, seed_msb, seed_lsb);
}
} // namespace clientlib
} // namespace concretelang
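A minimal usage sketch of the cached key generation above (not part of the diff; the KeySetCache constructor taking a backing directory path and the origin of `params` are assumptions, and includes are omitted):
// Hypothetical helper, for illustration only.
std::unique_ptr<KeySet> getOrGenerateKeys(ClientParameters &params) {
  KeySetCache cache("/tmp/keyset-cache"); // assumed constructor taking a directory path
  auto keySetOrErr = cache.generate(params, /*seed_msb=*/0, /*seed_lsb=*/0);
  if (keySetOrErr.has_error()) {
    std::cerr << keySetOrErr.error().mesg << "\n";
    return nullptr;
  }
  return std::move(keySetOrErr.value());
}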

View File

@@ -1,172 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iostream>
#include <stdlib.h>
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
namespace concretelang {
namespace clientlib {
using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(
const ClientParameters &clientParameters,
std::vector<clientlib::SharedScalarOrTensorData> &buffers)
: clientParameters(clientParameters) {
arguments = buffers;
}
PublicArguments::~PublicArguments() {}
outcome::checked<void, StringError>
PublicArguments::serialize(std::ostream &ostream) {
if (incorrectMode(ostream)) {
return StringError(
"PublicArguments::serialize: ostream should be in binary mode");
}
serializeVectorOfScalarOrTensorData(arguments, ostream);
if (ostream.bad()) {
return StringError(
"PublicArguments::serialize: cannot serialize public arguments");
}
return outcome::success();
}
outcome::checked<void, StringError>
PublicArguments::unserializeArgs(std::istream &istream) {
OUTCOME_TRY(arguments, unserializeVectorOfScalarOrTensorData(istream));
return outcome::success();
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
PublicArguments::unserialize(const ClientParameters &expectedParams,
std::istream &istream) {
std::vector<SharedScalarOrTensorData> emptyBuffers;
auto sArguments =
std::make_unique<PublicArguments>(expectedParams, emptyBuffers);
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
return std::move(sArguments);
}
outcome::checked<void, StringError>
PublicResult::unserialize(std::istream &istream) {
OUTCOME_TRY(buffers, unserializeVectorOfScalarOrTensorData(istream));
return outcome::success();
}
outcome::checked<void, StringError>
PublicResult::serialize(std::ostream &ostream) {
serializeVectorOfScalarOrTensorData(buffers, ostream);
if (ostream.bad()) {
return StringError("PublicResult::serialize: cannot serialize");
}
return outcome::success();
}
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
// increase multi dim index
for (int r = rank - 1; r >= 0; r--) {
if (index[r] < sizes[r] - 1) {
index[r]++;
return;
}
index[r] = 0;
}
}
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
size_t rank) {
// compute global index from multi dim index
size_t g_index = 0;
size_t default_stride = 1;
for (int r = rank - 1; r >= 0; r--) {
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
default_stride *= sizes[r];
}
return g_index;
}
static inline bool isReferenceToMLIRGlobalMemory(void *ptr) {
return reinterpret_cast<uintptr_t>(ptr) == 0xdeadbeef;
}
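// Copies the (possibly strided) MemRef described by the arguments into a new,
// densely packed TensorData, then frees the MemRef allocation unless it points
// to MLIR global memory.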
template <typename T>
TensorData tensorDataFromMemRefTyped(size_t memref_rank, void *allocatedVoid,
void *alignedVoid, size_t offset,
size_t *sizes, size_t *strides) {
T *allocated = reinterpret_cast<T *>(allocatedVoid);
T *aligned = reinterpret_cast<T *>(alignedVoid);
TensorData result(llvm::ArrayRef<size_t>{sizes, memref_rank}, sizeof(T) * 8,
std::is_signed<T>());
assert(aligned != nullptr);
// ephemeral multi dim index used to compute the global index
size_t *index = new size_t[memref_rank];
for (size_t r = 0; r < memref_rank; r++) {
index[r] = 0;
}
auto len = result.length();
// TODO: add a fast path for dense result (no real strides)
for (size_t i = 0; i < len; i++) {
int g_index = offset + global_index(index, sizes, strides, memref_rank);
result.getElementReference<T>(i) = aligned[g_index];
next_coord_index(index, sizes, memref_rank);
}
delete[] index;
// TEMPORARY: This is quick and dirty, but since this function is only used to
// convert a result of the MLIR program and the data is copied here, we release
// the allocated pointer if it is set.
if (allocated != nullptr && !isReferenceToMLIRGlobalMemory(allocated)) {
free(allocated);
}
return result;
}
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides) {
ElementType et = getElementTypeFromWidthAndSign(element_width, is_signed);
switch (et) {
case ElementType::i64:
return tensorDataFromMemRefTyped<int64_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::u64:
return tensorDataFromMemRefTyped<uint64_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::i32:
return tensorDataFromMemRefTyped<int32_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::u32:
return tensorDataFromMemRefTyped<uint32_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::i16:
return tensorDataFromMemRefTyped<int16_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::u16:
return tensorDataFromMemRefTyped<uint16_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::i8:
return tensorDataFromMemRefTyped<int8_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
case ElementType::u8:
return tensorDataFromMemRefTyped<uint8_t>(memref_rank, allocated, aligned,
offset, sizes, strides);
}
// Cannot happen
assert(false);
}
} // namespace clientlib
} // namespace concretelang
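For reference, a round-trip sketch of the stream-based serialization this commit removes (a sketch only: `args` and `params` are assumed to already exist, and the stream must be in binary mode as checked by `incorrectMode`):
std::stringstream buffer(std::ios::in | std::ios::out | std::ios::binary);
if (auto err = args.serialize(buffer); err.has_error()) {
  std::cerr << err.error().mesg << "\n";
}
auto restoredOrErr = PublicArguments::unserialize(params, buffer);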

View File

@@ -1,585 +0,0 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <iosfwd>
#include <iostream>
#include <stdlib.h>
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
namespace concretelang {
namespace clientlib {
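// Writes a key's 64-bit word buffer to the stream, prefixed by its length.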
template <typename Key>
std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) {
writeSize(ostream, (uint64_t)buffer.size());
ostream.write((const char *)buffer.buffer(),
buffer.size() * sizeof(uint64_t));
assert(ostream.good());
return ostream;
}
std::istream &operator>>(std::istream &istream,
std::shared_ptr<std::vector<uint64_t>> &vec) {
// TODO assertion on size?
uint64_t size;
readSize(istream, size);
vec->resize(size);
istream.read((char *)vec->data(), size * sizeof(uint64_t));
assert(istream.good());
return istream;
}
// LweSecretKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) {
writeWord(ostream, param.dimension);
return ostream;
}
std::istream &operator>>(std::istream &istream, LweSecretKeyParam &param) {
readWord(istream, param.dimension);
return istream;
}
// LweSecretKey /////////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweSecretKey readLweSecretKey(std::istream &istream) {
LweSecretKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweSecretKey(buffer, param);
}
// KeyswitchKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.variance);
return ostream;
}
std::istream &operator>>(std::istream &istream, KeyswitchKeyParam &param) {
// TODO keys id
param.outputSecretKeyID = 1234;
param.inputSecretKeyID = 1234;
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.variance);
return istream;
}
// LweKeyswitchKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) {
KeyswitchKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweKeyswitchKey(buffer, param);
}
// BootstrapKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.glweDimension);
writeWord(ostream, param.variance);
writeWord(ostream, param.polynomialSize);
writeWord(ostream, param.inputLweDimension);
return ostream;
}
std::istream &operator>>(std::istream &istream, BootstrapKeyParam &param) {
// TODO keys id
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.glweDimension);
readWord(istream, param.variance);
readWord(istream, param.polynomialSize);
readWord(istream, param.inputLweDimension);
return istream;
}
// LweBootstrapKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
LweBootstrapKey readLweBootstrapKey(std::istream &istream) {
BootstrapKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
return LweBootstrapKey(buffer, param);
}
// PackingKeyswitchKeyParam ////////////////////////////
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKeyParam param) {
// TODO keys id
writeWord(ostream, param.level);
writeWord(ostream, param.baseLog);
writeWord(ostream, param.glweDimension);
writeWord(ostream, param.polynomialSize);
writeWord(ostream, param.inputLweDimension);
writeWord(ostream, param.variance);
return ostream;
}
std::istream &operator>>(std::istream &istream,
PackingKeyswitchKeyParam &param) {
// TODO keys id
param.outputSecretKeyID = 1234;
param.inputSecretKeyID = 1234;
readWord(istream, param.level);
readWord(istream, param.baseLog);
readWord(istream, param.glweDimension);
readWord(istream, param.polynomialSize);
readWord(istream, param.inputLweDimension);
readWord(istream, param.variance);
return istream;
}
// PackingKeyswitchKey //////////////////////////////
std::ostream &operator<<(std::ostream &ostream,
const PackingKeyswitchKey &key) {
ostream << key.parameters();
writeUInt64KeyBuffer(ostream, key);
return ostream;
}
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) {
PackingKeyswitchKeyParam param;
istream >> param;
auto buffer = std::make_shared<std::vector<uint64_t>>();
istream >> buffer;
auto b = PackingKeyswitchKey(buffer, param);
return b;
}
// KeySet ////////////////////////////////
std::unique_ptr<KeySet> readKeySet(std::istream &istream) {
uint64_t nbKey;
readSize(istream, nbKey);
std::vector<LweSecretKey> secretKeys;
for (uint64_t i = 0; i < nbKey; i++) {
secretKeys.push_back(readLweSecretKey(istream));
}
readSize(istream, nbKey);
std::vector<LweBootstrapKey> bootstrapKeys;
for (uint64_t i = 0; i < nbKey; i++) {
bootstrapKeys.push_back(readLweBootstrapKey(istream));
}
readSize(istream, nbKey);
std::vector<LweKeyswitchKey> keyswitchKeys;
for (uint64_t i = 0; i < nbKey; i++) {
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
}
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
readSize(istream, nbKey);
for (uint64_t i = 0; i < nbKey; i++) {
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
}
std::string clientParametersString;
istream >> clientParametersString;
auto clientParameters =
llvm::json::parse<ClientParameters>(clientParametersString);
if (!clientParameters) {
return std::unique_ptr<KeySet>(nullptr);
}
auto csprng = ConcreteCSPRNG(0);
auto keySet =
KeySet::fromKeys(clientParameters.get(), secretKeys, bootstrapKeys,
keyswitchKeys, packingKeyswitchKeys, std::move(csprng));
return std::move(keySet.value());
}
std::ostream &operator<<(std::ostream &ostream, const KeySet &keySet) {
auto secretKeys = keySet.getSecretKeys();
writeSize(ostream, secretKeys.size());
for (auto sk : secretKeys) {
ostream << sk;
}
auto bootstrapKeys = keySet.getBootstrapKeys();
writeSize(ostream, bootstrapKeys.size());
for (auto bsk : bootstrapKeys) {
ostream << bsk;
}
auto keyswitchKeys = keySet.getKeyswitchKeys();
writeSize(ostream, keyswitchKeys.size());
for (auto ksk : keyswitchKeys) {
ostream << ksk;
}
auto packingKeyswitchKeys = keySet.getPackingKeyswitchKeys();
writeSize(ostream, packingKeyswitchKeys.size());
for (auto pksk : packingKeyswitchKeys) {
ostream << pksk;
}
auto clientParametersJson = llvm::json::Value(keySet.clientParameters());
std::string clientParametersString;
llvm::raw_string_ostream clientParametersStringBuffer(clientParametersString);
clientParametersStringBuffer << clientParametersJson;
ostream << clientParametersString;
assert(ostream.good());
return ostream;
}
// EvaluationKey ////////////////////////////////
EvaluationKeys readEvaluationKeys(std::istream &istream) {
uint64_t nbKey;
readSize(istream, nbKey);
std::vector<LweBootstrapKey> bootstrapKeys;
for (uint64_t i = 0; i < nbKey; i++) {
bootstrapKeys.push_back(readLweBootstrapKey(istream));
}
readSize(istream, nbKey);
std::vector<LweKeyswitchKey> keyswitchKeys;
for (uint64_t i = 0; i < nbKey; i++) {
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
}
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
readSize(istream, nbKey);
for (uint64_t i = 0; i < nbKey; i++) {
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
}
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
}
std::ostream &operator<<(std::ostream &ostream,
const EvaluationKeys &evaluationKeys) {
auto bootstrapKeys = evaluationKeys.getBootstrapKeys();
writeSize(ostream, bootstrapKeys.size());
for (auto bsk : bootstrapKeys) {
ostream << bsk;
}
auto keyswitchKeys = evaluationKeys.getKeyswitchKeys();
writeSize(ostream, keyswitchKeys.size());
for (auto ksk : keyswitchKeys) {
ostream << ksk;
}
auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys();
writeSize(ostream, packingKeyswitchKeys.size());
for (auto pksk : packingKeyswitchKeys) {
ostream << pksk;
}
assert(ostream.good());
return ostream;
}
// TensorData ///////////////////////////////////
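// Scalars are serialized as: element width in bits (u64), signedness flag (u8),
// then the raw value.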
template <typename T>
std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) {
writeWord<uint64_t>(ostream, sizeof(T) * 8);
writeWord<uint8_t>(ostream, std::is_signed<T>());
writeWord<T>(ostream, value);
return ostream;
}
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) {
switch (sd.getType()) {
case ElementType::u64:
return serializeScalarDataRaw<uint64_t>(sd.getValue<uint64_t>(), ostream);
case ElementType::i64:
return serializeScalarDataRaw<int64_t>(sd.getValue<int64_t>(), ostream);
case ElementType::u32:
return serializeScalarDataRaw<uint32_t>(sd.getValue<uint32_t>(), ostream);
case ElementType::i32:
return serializeScalarDataRaw<int32_t>(sd.getValue<int32_t>(), ostream);
case ElementType::u16:
return serializeScalarDataRaw<uint16_t>(sd.getValue<uint16_t>(), ostream);
case ElementType::i16:
return serializeScalarDataRaw<int16_t>(sd.getValue<int16_t>(), ostream);
case ElementType::u8:
return serializeScalarDataRaw<uint8_t>(sd.getValue<uint8_t>(), ostream);
case ElementType::i8:
return serializeScalarDataRaw<int8_t>(sd.getValue<int8_t>(), ostream);
}
return ostream;
}
template <typename T> ScalarData unserializeScalarValue(std::istream &istream) {
T value;
readWord(istream, value);
return ScalarData(value);
}
outcome::checked<ScalarData, StringError>
unserializeScalarData(std::istream &istream) {
uint64_t scalarWidth;
readWord(istream, scalarWidth);
switch (scalarWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Scalar width must be either 64, 32, 16 or 8, but got ")
<< scalarWidth;
}
uint8_t scalarSignedness;
readWord(istream, scalarSignedness);
if (scalarSignedness != 0 && scalarSignedness != 1) {
return StringError("Numerical value for scalar signedness must be either "
"0 or 1, but got ")
<< scalarSignedness;
}
switch (scalarWidth) {
case 64:
return (scalarSignedness) ? unserializeScalarValue<int64_t>(istream)
: unserializeScalarValue<uint64_t>(istream);
case 32:
return (scalarSignedness) ? unserializeScalarValue<int32_t>(istream)
: unserializeScalarValue<uint32_t>(istream);
case 16:
return (scalarSignedness) ? unserializeScalarValue<int16_t>(istream)
: unserializeScalarValue<uint16_t>(istream);
case 8:
return (scalarSignedness) ? unserializeScalarValue<int8_t>(istream)
: unserializeScalarValue<uint8_t>(istream);
}
assert(false && "Unhandled scalar type");
}
template <typename T>
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
std::istream &istream) {
// getElementPointer is not valid if the tensor contains no data
if (values_and_sizes.getNumElements() > 0) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());
}
return istream;
}
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream) {
switch (values_and_sizes.getElementType()) {
case ElementType::u64:
return serializeTensorDataRaw<uint64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint64_t>(), ostream);
case ElementType::i64:
return serializeTensorDataRaw<int64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int64_t>(), ostream);
case ElementType::u32:
return serializeTensorDataRaw<uint32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint32_t>(), ostream);
case ElementType::i32:
return serializeTensorDataRaw<int32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int32_t>(), ostream);
case ElementType::u16:
return serializeTensorDataRaw<uint16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint16_t>(), ostream);
case ElementType::i16:
return serializeTensorDataRaw<int16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int16_t>(), ostream);
case ElementType::u8:
return serializeTensorDataRaw<uint8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint8_t>(), ostream);
case ElementType::i8:
return serializeTensorDataRaw<int8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int8_t>(), ostream);
}
assert(false && "Unhandled element type");
}
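// Tensors are unserialized as: rank (u64), dimension sizes (i64 each), element
// width in bits (u64), signedness flag (u8), then the raw elements.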
outcome::checked<TensorData, StringError>
unserializeTensorData(std::istream &istream) {
if (incorrectMode(istream)) {
return StringError("Stream is in incorrect mode");
}
uint64_t numDimensions;
readWord(istream, numDimensions);
std::vector<size_t> dims;
for (uint64_t i = 0; i < numDimensions; i++) {
int64_t dimSize;
readWord(istream, dimSize);
dims.push_back(dimSize);
}
uint64_t elementWidth;
readWord(istream, elementWidth);
switch (elementWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
<< elementWidth;
}
uint8_t elementSignedness;
readWord(istream, elementSignedness);
if (elementSignedness != 0 && elementSignedness != 1) {
return StringError("Numerical value for element signedness must be either "
"0 or 1, but got ")
<< elementSignedness;
}
TensorData result(dims, elementWidth, elementSignedness == 1);
switch (result.getElementType()) {
case ElementType::u64:
unserializeTensorDataElements<uint64_t>(result, istream);
break;
case ElementType::i64:
unserializeTensorDataElements<int64_t>(result, istream);
break;
case ElementType::u32:
unserializeTensorDataElements<uint32_t>(result, istream);
break;
case ElementType::i32:
unserializeTensorDataElements<int32_t>(result, istream);
break;
case ElementType::u16:
unserializeTensorDataElements<uint16_t>(result, istream);
break;
case ElementType::i16:
unserializeTensorDataElements<int16_t>(result, istream);
break;
case ElementType::u8:
unserializeTensorDataElements<uint8_t>(result, istream);
break;
case ElementType::i8:
unserializeTensorDataElements<int8_t>(result, istream);
break;
}
return std::move(result);
}
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
std::ostream &ostream) {
writeWord<uint8_t>(ostream, sotd.isTensor());
if (sotd.isTensor())
return serializeTensorData(sotd.getTensor(), ostream);
else
return serializeScalarData(sotd.getScalar(), ostream);
}
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(std::istream &istream) {
uint8_t isTensor;
readWord(istream, isTensor);
if (isTensor != 0 && isTensor != 1) {
return StringError("Numerical value indicating whether a data element is a "
"tensor must be either 0 or 1, but got ")
<< isTensor;
}
if (isTensor) {
auto tdOrErr = unserializeTensorData(istream);
if (tdOrErr.has_error())
return std::move(tdOrErr.error());
else
return ScalarOrTensorData(std::move(tdOrErr.value()));
} else {
auto tdOrErr = unserializeScalarData(istream);
if (tdOrErr.has_error())
return std::move(tdOrErr.error());
else
return ScalarOrTensorData(std::move(tdOrErr.value()));
}
}
std::ostream &serializeVectorOfScalarOrTensorData(
const std::vector<SharedScalarOrTensorData> &v, std::ostream &ostream) {
writeSize(ostream, v.size());
for (auto &sotd : v) {
serializeScalarOrTensorData(sotd.get(), ostream);
if (!ostream.good()) {
return ostream;
}
}
return ostream;
}
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
uint64_t nbElt;
readSize(istream, nbElt);
std::vector<SharedScalarOrTensorData> v;
for (uint64_t i = 0; i < nbElt; i++) {
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
v.push_back(SharedScalarOrTensorData(std::move(elt)));
}
return v;
}
} // namespace clientlib
} // namespace concretelang
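Evaluation keys round-trip the same way through the stream operators above; a sketch, assuming an existing `evalKeys` value and omitting error handling:
std::stringstream stream(std::ios::in | std::ios::out | std::ios::binary);
stream << evalKeys;
EvaluationKeys restored = readEvaluationKeys(stream);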

View File

@@ -0,0 +1,20 @@
add_compile_options(-fexceptions -fsized-deallocation -fno-rtti)
add_mlir_library(
ConcretelangCommon
Protocol.cpp
CRT.cpp
Csprng.cpp
Keys.cpp
Keysets.cpp
Transformers.cpp
Values.cpp
DEPENDS
concrete-protocol
LINK_LIBS
PUBLIC
concrete_cpu
kj
capnp)
target_include_directories(ConcretelangCommon PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})

Some files were not shown because too many files have changed in this diff.