mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-18 00:21:36 -05:00
feat(compiler): introduce concrete-protocol
This commit: + Adds support for a protocol which enables inter-op between concrete, tfhe-rs and potentially other contributors to the fhe ecosystem. + Gets rid of hand-made serialization in the compiler, and client/server libs. + Refactors client/server libs to allow more pre/post processing of circuit inputs/outputs. The protocol is supported by a definition in the shape of a capnp file, which defines different types of objects among which: + ProgramInfo object, which is a precise description of a set of fhe circuit coming from the same compilation (understand function type information), and the associated key set. + *Key objects, which represent secret/public keys used to encrypt/execute fhe circuits. + Value object, which represent values that can be transferred between client and server to support calls to fhe circuits. The hand-rolled serialization that was previously used is completely dropped in favor of capnp in the whole codebase. The client/server libs, are refactored to introduce a modular design for pre-post processing. Reading the ProgramInfo file associated with a compilation, the client and server libs assemble a pipeline of transformers (functions) for pre and post processing of values coming in and out of a circuit. This design properly decouples various aspects of the processing, and allows these capabilities to be safely extended. In practice this commit includes the following: + Defines the specification in a concreteprotocol package + Integrate the compilation of this package as a compiler dependency via cmake + Modify the compiler to use the Encodings objects defined in the protocol + Modify the compiler to emit ProgramInfo files as compilation artifact, and gets rid of the bloated ClientParameters. + Introduces a new Common library containing the functionalities shared between the compiler and the client/server libs. + Introduces a functional pre-post processing pipeline to this common library + Modify the client/server libs to support loading ProgramInfo objects, and calling circuits using Value messages. + Drops support of JIT. + Drops support of C-api. + Drops support of Rust bindings. Co-authored-by: Nikita Frolov <nf@mkmks.org>
This commit is contained in:
committed by
Alexandre Péré
parent
9139101cc3
commit
e8ef48ffd8
2
compilers/concrete-compiler/.gitignore
vendored
2
compilers/concrete-compiler/.gitignore
vendored
@@ -48,5 +48,3 @@ _build/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
compiler/lib/Bindings/Rust/target/
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Build dirs
|
||||
build*/
|
||||
.cache/
|
||||
|
||||
*.mlir.script
|
||||
*.lit_test_times.txt
|
||||
|
||||
@@ -3,13 +3,14 @@ cmake_minimum_required(VERSION 3.17)
|
||||
project(concretecompiler LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set_property(GLOBAL PROPERTY GLOBAL_DEPENDS_DEBUG_MODE 0)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
# Needed on linux with clang 15 and on MacOS because cxx emits dollars in the optimizer C++ API
|
||||
add_definitions("-Wno-dollar-in-identifier-extension")
|
||||
|
||||
add_definitions("-Wno-c++98-compat-extra-semi")
|
||||
add_definitions("-Wall ")
|
||||
add_definitions("-Werror ")
|
||||
add_definitions("-Wfatal-errors")
|
||||
@@ -66,6 +67,19 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
# -------------------------------------------------------------------------------
|
||||
include_directories(${PROJECT_SOURCE_DIR}/../../../tools/parameter-curves/concrete-security-curves-cpp/include)
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Concrete Protocol
|
||||
# -------------------------------------------------------------------------------
|
||||
set(CONCRETE_PROTOCOL_DIR "${PROJECT_SOURCE_DIR}/../../../tools/concrete-protocol")
|
||||
add_subdirectory(${CONCRETE_PROTOCOL_DIR} concrete-protocol)
|
||||
get_target_property(CONCRETE_PROTOCOL_GEN_DIR concrete-protocol BINARY_DIR)
|
||||
set(CONCRETE_PROTOCOL_CAPNP_SRC "${CONCRETE_PROTOCOL_GEN_DIR}/capnp_src_dir/c++/src")
|
||||
include_directories(${CONCRETE_PROTOCOL_GEN_DIR})
|
||||
include_directories(${CONCRETE_PROTOCOL_CAPNP_SRC})
|
||||
add_dependencies(mlir-headers concrete-protocol)
|
||||
install(TARGETS concrete-protocol EXPORT concrete-protocol)
|
||||
install(EXPORT concrete-protocol DESTINATION "./")
|
||||
|
||||
# -------------------------------------------------------------------------------
|
||||
# Concrete Optimizer
|
||||
# -------------------------------------------------------------------------------
|
||||
|
||||
@@ -110,7 +110,7 @@ else
|
||||
PYTHON_TESTS_MARKER="not parallel"
|
||||
endif
|
||||
|
||||
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc rust-bindings
|
||||
all: concretecompiler python-bindings build-tests build-benchmarks build-mlbench doc
|
||||
|
||||
# HPX #####################################################
|
||||
|
||||
@@ -174,14 +174,6 @@ python-bindings: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
|
||||
|
||||
rust-bindings: install
|
||||
cd lib/Bindings/Rust && \
|
||||
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
|
||||
cargo build --release
|
||||
|
||||
CAPI:
|
||||
cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG CONCRETELANGCAPISupport
|
||||
|
||||
clientlib: build-initialized
|
||||
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
|
||||
|
||||
@@ -249,14 +241,6 @@ run-python-tests: python-bindings concretecompiler
|
||||
test-compiler-file-output: concretecompiler
|
||||
pytest -vs tests/test_compiler_file_output
|
||||
|
||||
|
||||
## rust-tests
|
||||
run-rust-tests: rust-bindings
|
||||
cd lib/Bindings/Rust && \
|
||||
CONCRETE_COMPILER_INSTALL_DIR=$(INSTALL_PATH) \
|
||||
LD_LIBRARY_PATH=$(INSTALL_PATH)/lib \
|
||||
cargo test --release
|
||||
|
||||
## end-to-end-tests
|
||||
|
||||
build-end-to-end-jit-chunked-int: build-initialized
|
||||
@@ -307,7 +291,7 @@ run-end-to-end-tests: $(GTEST_PARALLEL_PY) build-end-to-end-tests generate-cpu-t
|
||||
$(foreach optimizer_strategy,$(OPTIMIZATION_STRATEGY_TO_TEST), $(foreach security,$(SECURITY_TO_TEST), \
|
||||
$(GTEST_PARALLEL_CMD) $(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
|
||||
$(GTEST_PARALLEL_SEPARATOR) --backend=cpu --security-level=$(security) \
|
||||
--optimizer-strategy=$(optimizer_strategy) --jit $(FIXTURE_CPU_DIR)/*.yaml || exit $$?;))
|
||||
--optimizer-strategy=$(optimizer_strategy) $(FIXTURE_CPU_DIR)/*.yaml || exit $$?;))
|
||||
|
||||
### end-to-end-tests GPU
|
||||
|
||||
@@ -327,7 +311,7 @@ generate-gpu-tests: $(FIXTURE_GPU_DIR) $(FIXTURE_GPU_DIR)/end_to_end_apply_looku
|
||||
|
||||
run-end-to-end-tests-gpu: build-end-to-end-test generate-gpu-tests
|
||||
$(BUILD_DIR)/tools/concretelang/tests/end_to_end_tests/end_to_end_test \
|
||||
--backend=gpu --library /tmp/concrete_compiler/gpu_tests/ \
|
||||
--backend=gpu \
|
||||
$(FIXTURE_GPU_DIR)/*.yaml
|
||||
|
||||
## end-to-end-dataflow-tests
|
||||
@@ -477,12 +461,6 @@ python-format:
|
||||
python-lint:
|
||||
pylint --rcfile=../pylintrc lib/Bindings/Python/concrete/compiler
|
||||
|
||||
check-rust-format:
|
||||
cd lib/Bindings/Rust && cargo fmt --check
|
||||
|
||||
rust-format:
|
||||
cd lib/Bindings/Rust && cargo fmt
|
||||
|
||||
# libraries we want to have in the installation that aren't already a deps of other targets
|
||||
install-deps:
|
||||
cmake --build $(BUILD_DIR) --target MLIRCAPIRegisterEverything
|
||||
@@ -529,7 +507,7 @@ darwin-python-package:
|
||||
python-package: python-bindings $(OS)-python-package
|
||||
@echo The python package is: $(BUILD_DIR)/wheels/*.whl
|
||||
|
||||
install: concretecompiler CAPI install-deps
|
||||
install: concretecompiler install-deps
|
||||
$(info Install prefix set to $(INSTALL_PREFIX))
|
||||
$(info Installing under $(INSTALL_PATH))
|
||||
mkdir -p $(INSTALL_PATH)/include
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHE_H
|
||||
#define CONCRETELANG_C_DIALECT_FHE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// \brief structure to return an MlirType or report that there was an error
|
||||
/// during type creation.
|
||||
typedef struct {
|
||||
MlirType type;
|
||||
bool isError;
|
||||
} MlirTypeOrError;
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
|
||||
|
||||
/// Creates an encrypted integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirTypeOrError
|
||||
fheEncryptedIntegerTypeGetChecked(MlirContext context, unsigned width);
|
||||
|
||||
/// If the type is an EncryptedInteger
|
||||
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedIntegerType(MlirType);
|
||||
|
||||
/// Creates an encrypted signed integer type of `width` bits
|
||||
MLIR_CAPI_EXPORTED MlirTypeOrError
|
||||
fheEncryptedSignedIntegerTypeGetChecked(MlirContext context, unsigned width);
|
||||
|
||||
/// If the type is an EncryptedSignedInteger
|
||||
MLIR_CAPI_EXPORTED bool fheTypeIsAnEncryptedSignedIntegerType(MlirType);
|
||||
|
||||
/// \brief Get bitwidth of the encrypted integer type.
|
||||
///
|
||||
/// \return bitwidth of the encrypted integer or 0 if it's not an encrypted
|
||||
/// integer
|
||||
MLIR_CAPI_EXPORTED unsigned fheTypeIntegerWidthGet(MlirType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_FHE_H
|
||||
@@ -1,21 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
#define CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_FHELINALG_H
|
||||
@@ -1,21 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_DIALECT_TRACING_H
|
||||
#define CONCRETELANG_C_DIALECT_TRACING_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TRACING, tracing);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_DIALECT_TRACING_H
|
||||
@@ -1,406 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/// The CAPI should be really careful about memory allocation. Every pointer
|
||||
/// returned should points to a new buffer allocated for the purpose of the
|
||||
/// CAPI, and should have a respective destructor function.
|
||||
|
||||
/// Opaque type declarations. Inspired from
|
||||
/// llvm-project/mlir/include/mlir-c/IR.h
|
||||
///
|
||||
/// Adds an error pointer to an allocated buffer holding the error message if
|
||||
/// any.
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
const char *error; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(CompilerEngine, void);
|
||||
DEFINE_C_API_STRUCT(CompilationContext, void);
|
||||
DEFINE_C_API_STRUCT(CompilationResult, void);
|
||||
DEFINE_C_API_STRUCT(Library, void);
|
||||
DEFINE_C_API_STRUCT(LibraryCompilationResult, void);
|
||||
DEFINE_C_API_STRUCT(LibrarySupport, void);
|
||||
DEFINE_C_API_STRUCT(CompilationOptions, void);
|
||||
DEFINE_C_API_STRUCT(OptimizerConfig, void);
|
||||
DEFINE_C_API_STRUCT(ServerLambda, void);
|
||||
DEFINE_C_API_STRUCT(Encoding, void);
|
||||
DEFINE_C_API_STRUCT(EncryptionGate, void);
|
||||
DEFINE_C_API_STRUCT(CircuitGate, void);
|
||||
DEFINE_C_API_STRUCT(ClientParameters, void);
|
||||
DEFINE_C_API_STRUCT(KeySet, void);
|
||||
DEFINE_C_API_STRUCT(KeySetCache, void);
|
||||
DEFINE_C_API_STRUCT(EvaluationKeys, void);
|
||||
DEFINE_C_API_STRUCT(LambdaArgument, void);
|
||||
DEFINE_C_API_STRUCT(PublicArguments, void);
|
||||
DEFINE_C_API_STRUCT(PublicResult, void);
|
||||
DEFINE_C_API_STRUCT(CompilationFeedback, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
/// NULL Pointer checkers. Generate functions to check if the struct contains a
|
||||
/// null pointer.
|
||||
#define DEFINE_NULL_PTR_CHECKER(funcname, storage) \
|
||||
bool funcname(storage s) { return s.ptr == NULL; }
|
||||
|
||||
DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationContextIsNull, CompilationContext)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult)
|
||||
DEFINE_NULL_PTR_CHECKER(libraryIsNull, Library)
|
||||
DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull,
|
||||
LibraryCompilationResult)
|
||||
DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions)
|
||||
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig)
|
||||
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda)
|
||||
DEFINE_NULL_PTR_CHECKER(circuitGateIsNull, CircuitGate)
|
||||
DEFINE_NULL_PTR_CHECKER(encodingIsNull, Encoding)
|
||||
DEFINE_NULL_PTR_CHECKER(encryptionGateIsNull, EncryptionGate)
|
||||
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet)
|
||||
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache)
|
||||
DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys)
|
||||
DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument)
|
||||
DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments)
|
||||
DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult)
|
||||
DEFINE_NULL_PTR_CHECKER(compilationFeedbackIsNull, CompilationFeedback)
|
||||
|
||||
#undef DEFINE_NULL_PTR_CHECKER
|
||||
|
||||
/// Each struct has a creator function that allocates memory for the underlying
|
||||
/// Cpp object referenced, and a destroy function that does free this allocated
|
||||
/// memory.
|
||||
|
||||
/// ********** Utilities *******************************************************
|
||||
|
||||
/// Destroy string references created by the compiler.
|
||||
///
|
||||
/// This is not supposed to destroy any string ref, but only the ones we have
|
||||
/// allocated memory for and know how to free.
|
||||
MLIR_CAPI_EXPORTED void mlirStringRefDestroy(MlirStringRef str);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool mlirStringRefIsNull(MlirStringRef str) {
|
||||
return str.data == NULL;
|
||||
}
|
||||
|
||||
/// ********** BufferRef CAPI **************************************************
|
||||
|
||||
/// A struct for binary buffers.
|
||||
///
|
||||
/// Contraty to MlirStringRef, it doesn't assume the pointer point to a null
|
||||
/// terminated string and the data should be considered as is in binary form.
|
||||
/// Useful for serialized objects.
|
||||
typedef struct BufferRef {
|
||||
const char *data;
|
||||
size_t length;
|
||||
const char *error;
|
||||
} BufferRef;
|
||||
|
||||
MLIR_CAPI_EXPORTED void bufferRefDestroy(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool bufferRefIsNull(BufferRef buffer) {
|
||||
return buffer.data == NULL;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef bufferRefCreate(const char *buffer, size_t length);
|
||||
|
||||
/// ********** CompilationTarget CAPI ******************************************
|
||||
|
||||
enum CompilationTarget {
|
||||
ROUND_TRIP,
|
||||
FHE,
|
||||
TFHE,
|
||||
PARAMETRIZED_TFHE,
|
||||
NORMALIZED_TFHE,
|
||||
BATCHED_TFHE,
|
||||
CONCRETE,
|
||||
STD,
|
||||
LLVM,
|
||||
LLVM_IR,
|
||||
OPTIMIZED_LLVM_IR,
|
||||
LIBRARY
|
||||
};
|
||||
typedef enum CompilationTarget CompilationTarget;
|
||||
|
||||
/// ********** CompilationOptions CAPI *****************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(
|
||||
MlirStringRef funcName, bool autoParallelize, bool batchTFHEOps,
|
||||
bool dataflowParallelize, bool emitGPUOps, bool loopParallelize,
|
||||
bool optimizeTFHE, OptimizerConfig optimizerConfig, bool verifyDiagnostics);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault();
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilationOptionsDestroy(CompilationOptions options);
|
||||
|
||||
/// ********** OptimizerConfig CAPI ********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED OptimizerConfig
|
||||
optimizerConfigCreate(bool display, double fallback_log_norm_woppbs,
|
||||
double global_p_error, double p_error, uint64_t security,
|
||||
bool strategy_v0, bool use_gpu_constraints,
|
||||
uint32_t ciphertext_modulus_log, uint32_t fft_precision);
|
||||
|
||||
MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault();
|
||||
|
||||
MLIR_CAPI_EXPORTED void optimizerConfigDestroy(OptimizerConfig config);
|
||||
|
||||
/// ********** CompilerEngine CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate();
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile(
|
||||
CompilerEngine engine, MlirStringRef module, CompilationTarget target);
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
compilerEngineCompileSetOptions(CompilerEngine engine,
|
||||
CompilationOptions options);
|
||||
|
||||
/// ********** CompilationResult CAPI ******************************************
|
||||
|
||||
/// Get a string reference holding the textual representation of the compiled
|
||||
/// module. The returned `MlirStringRef` should be destroyed using
|
||||
/// `mlirStringRefDestroy` to free memory.
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
compilationResultGetModuleString(CompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED void compilationResultDestroy(CompilationResult result);
|
||||
|
||||
/// ********** Library CAPI ****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED Library libraryCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath,
|
||||
bool cleanUp);
|
||||
|
||||
MLIR_CAPI_EXPORTED void libraryDestroy(Library lib);
|
||||
|
||||
/// ********** LibraryCompilationResult CAPI ***********************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
libraryCompilationResultDestroy(LibraryCompilationResult result);
|
||||
|
||||
/// ********** LibrarySupport CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport
|
||||
librarySupportCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
|
||||
bool generateStaticLib, bool generateClientParameters,
|
||||
bool generateCompilationFeedback, bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault(
|
||||
MlirStringRef outputDirPath, MlirStringRef runtimeLibraryPath) {
|
||||
return librarySupportCreate(outputDirPath, runtimeLibraryPath, true, true,
|
||||
true, true, true);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
|
||||
LibrarySupport support, MlirStringRef module, CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED LibraryCompilationResult
|
||||
librarySupportLoadCompilationResult(LibrarySupport support);
|
||||
|
||||
MLIR_CAPI_EXPORTED CompilationFeedback librarySupportLoadCompilationFeedback(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicResult
|
||||
librarySupportServerCall(LibrarySupport support, ServerLambda server,
|
||||
PublicArguments args, EvaluationKeys evalKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
librarySupportGetSharedLibPath(LibrarySupport support);
|
||||
|
||||
MLIR_CAPI_EXPORTED MlirStringRef
|
||||
librarySupportGetClientParametersPath(LibrarySupport support);
|
||||
|
||||
MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support);
|
||||
|
||||
/// ********** ServerLamda CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
|
||||
|
||||
/// ********** ClientParameters CAPI *******************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters
|
||||
clientParametersUnserialize(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters
|
||||
clientParametersCopy(ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
|
||||
|
||||
/// Returns the number of output circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersOutputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the number of input circuit gates
|
||||
MLIR_CAPI_EXPORTED size_t clientParametersInputsSize(ClientParameters params);
|
||||
|
||||
/// Returns the output circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersOutputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the input circuit gate corresponding to the index
|
||||
///
|
||||
/// - `index` must be valid.
|
||||
MLIR_CAPI_EXPORTED CircuitGate
|
||||
clientParametersInputCircuitGate(ClientParameters params, size_t index);
|
||||
|
||||
/// Returns the EncryptionGate of the circuit gate.
|
||||
///
|
||||
/// - The returned gate will be null if the gate does not represent encrypted
|
||||
/// data
|
||||
MLIR_CAPI_EXPORTED EncryptionGate
|
||||
circuitGateEncryptionGate(CircuitGate circuit_gate);
|
||||
|
||||
/// Returns the variance of the encryption gate
|
||||
MLIR_CAPI_EXPORTED double
|
||||
encryptionGateVariance(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the Encoding of the encryption gate.
|
||||
MLIR_CAPI_EXPORTED Encoding
|
||||
encryptionGateEncoding(EncryptionGate encryption_gate);
|
||||
|
||||
/// Returns the precision (bit width) of the encoding
|
||||
MLIR_CAPI_EXPORTED uint64_t encodingPrecision(Encoding encoding);
|
||||
|
||||
MLIR_CAPI_EXPORTED void circuitGateDestroy(CircuitGate gate);
|
||||
MLIR_CAPI_EXPORTED void encryptionGateDestroy(EncryptionGate gate);
|
||||
MLIR_CAPI_EXPORTED void encodingDestroy(Encoding encoding);
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet);
|
||||
|
||||
/// ********** KeySetCache CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath);
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet
|
||||
keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
|
||||
|
||||
/// ********** EvaluationKeys CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef evaluationKeysSerialize(EvaluationKeys keys);
|
||||
|
||||
MLIR_CAPI_EXPORTED EvaluationKeys evaluationKeysUnserialize(BufferRef buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
|
||||
|
||||
/// ********** LambdaArgument CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU8(
|
||||
const uint8_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU16(
|
||||
const uint16_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU32(
|
||||
const uint32_t *data, const int64_t *dims, size_t rank);
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(
|
||||
const uint64_t *data, const int64_t *dims, size_t rank);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg,
|
||||
uint64_t *buffer);
|
||||
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED int64_t
|
||||
lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg,
|
||||
int64_t *buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
|
||||
ClientParameters params, KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
|
||||
|
||||
/// ********** PublicArguments CAPI ********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef publicArgumentsSerialize(PublicArguments args);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
publicArgumentsUnserialize(BufferRef buffer, ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
|
||||
|
||||
/// ********** PublicResult CAPI ***********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
|
||||
KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED BufferRef publicResultSerialize(PublicResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicResult
|
||||
publicResultUnserialize(BufferRef buffer, ClientParameters params);
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
|
||||
|
||||
/// ********** CompilationFeedback CAPI ****************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetComplexity(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetPError(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED double
|
||||
compilationFeedbackGetGlobalPError(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalSecretKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalBootstrapKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalKeyswitchKeysSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalInputsSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
compilationFeedbackGetTotalOutputsSize(CompilationFeedback feedback);
|
||||
|
||||
MLIR_CAPI_EXPORTED void
|
||||
compilationFeedbackDestroy(CompilationFeedback feedback);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CONCRETELANG_C_SUPPORT_COMPILER_ENGINE_H
|
||||
@@ -6,10 +6,8 @@
|
||||
#ifndef CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_BINDINGS_PYTHON_COMPILER_ENGINE_H
|
||||
|
||||
#include "concretelang/Common/Compat.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/LibrarySupport.h"
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
/// MLIR_CAPI_EXPORTED is used here throughout the API, because of the way the
|
||||
@@ -29,42 +27,6 @@ struct executionArguments {
|
||||
};
|
||||
typedef struct executionArguments executionArguments;
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
struct JITSupport_Py {
|
||||
mlir::concretelang::JITSupport support;
|
||||
};
|
||||
typedef struct JITSupport_Py JITSupport_Py;
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITSupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile_module(
|
||||
JITSupport_Py support, mlir::ModuleOp module,
|
||||
mlir::concretelang::CompilationOptions options,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
struct LibrarySupport_Py {
|
||||
@@ -78,17 +40,17 @@ library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateClientParameters, bool generateCompilationFeedback,
|
||||
bool generateCppHeader);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile_module(
|
||||
LibrarySupport_Py support, mlir::ModuleOp module,
|
||||
mlir::concretelang::CompilationOptions options,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx);
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
|
||||
library_load_client_parameters(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
|
||||
@@ -98,7 +60,8 @@ library_load_compilation_feedback(
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &);
|
||||
mlir::concretelang::LibraryCompilationResult &,
|
||||
bool useSimulation);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
library_server_call(LibrarySupport_Py support,
|
||||
@@ -115,7 +78,7 @@ MLIR_CAPI_EXPORTED std::string
|
||||
library_get_shared_lib_path(LibrarySupport_Py support);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_Py support);
|
||||
library_get_program_info_path(LibrarySupport_Py support);
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
|
||||
@@ -130,28 +93,30 @@ encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args);
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult);
|
||||
|
||||
// Serialization ////////////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms);
|
||||
clientParametersSerialize(concretelang::clientlib::ClientParameters ¶ms);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
publicArgumentsUnserialize(
|
||||
mlir::concretelang::ClientParameters &clientParameters,
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
|
||||
concretelang::clientlib::PublicArguments &publicArguments);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
publicResultUnserialize(
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &buffer);
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult);
|
||||
@@ -174,6 +139,22 @@ valueUnserialize(const std::string &buffer);
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueExporter createValueExporter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueExporter
|
||||
createSimulatedValueExporter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueDecrypter createValueDecrypter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters);
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueDecrypter
|
||||
createSimulatedValueDecrypter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters);
|
||||
|
||||
/// Parse then print a textual representation of an MLIR module
|
||||
MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENT_LAMBDA_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
using scalar_in = uint8_t;
|
||||
using scalar_out = uint64_t;
|
||||
using tensor1_in = std::vector<scalar_in>;
|
||||
using tensor2_in = std::vector<std::vector<scalar_in>>;
|
||||
using tensor3_in = std::vector<std::vector<std::vector<scalar_in>>>;
|
||||
using tensor1_out = std::vector<scalar_out>;
|
||||
using tensor2_out = std::vector<std::vector<scalar_out>>;
|
||||
using tensor3_out = std::vector<std::vector<std::vector<scalar_out>>>;
|
||||
|
||||
/// Low-level class to create the client side view of a FHE function.
|
||||
class ClientLambda {
|
||||
public:
|
||||
virtual ~ClientLambda() = default;
|
||||
|
||||
/// Construct a ClientLambda from a ClientParameter file.
|
||||
static outcome::checked<ClientLambda, StringError> load(std::string funcName,
|
||||
std::string jsonPath);
|
||||
|
||||
/// Generate or get from cache a KeySet suitable for this ClientLambda
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
keySet(std::shared_ptr<KeySetCache> optionalCache, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptReturnedValues(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
decryptReturnedScalar(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
decryptReturnedTensor1(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
decryptReturnedTensor2(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
decryptReturnedTensor3(KeySet &keySet, PublicResult &result);
|
||||
|
||||
public:
|
||||
ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedClientLambda : public ClientLambda {
|
||||
|
||||
public:
|
||||
static outcome::checked<TypedClientLambda<Result, Args...>, StringError>
|
||||
load(std::string funcName, std::string jsonPath) {
|
||||
OUTCOME_TRY(auto lambda, ClientLambda::load(funcName, jsonPath));
|
||||
return TypedClientLambda(lambda);
|
||||
}
|
||||
|
||||
/// Emit a call on this lambda to a binary ostream.
|
||||
/// The ostream is responsible for transporting the call to a
|
||||
/// ServerLambda::real_call_write function. ostream must be in binary mode
|
||||
/// std::ios_base::openmode::binary
|
||||
outcome::checked<void, StringError>
|
||||
serializeCall(Args... args, KeySet &keySet, std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
|
||||
return publicArguments->serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
publicArguments(Args... args, KeySet &keySet) {
|
||||
OUTCOME_TRY(
|
||||
auto clientArguments,
|
||||
EncryptedArguments::create(/*simulation*/ false, keySet, args...));
|
||||
|
||||
return clientArguments->exportPublicArguments(clientParameters);
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return topLevelDecryptResult<Result>((*this), keySet, result);
|
||||
}
|
||||
|
||||
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
|
||||
// TODO: check parameter types
|
||||
// TODO: add static check on types vs lambda inputs/outpus
|
||||
}
|
||||
|
||||
protected:
|
||||
// Workaround, gcc 6 does not support partial template specialisation in class
|
||||
template <typename Result_>
|
||||
friend outcome::checked<Result_, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
};
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,89 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_REFACTORED_H
|
||||
#define CONCRETELANG_CLIENTLIB_REFACTORED_H
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Transformers.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::keysets::ClientKeyset;
|
||||
using concretelang::transformers::InputTransformer;
|
||||
using concretelang::transformers::OutputTransformer;
|
||||
using concretelang::transformers::TransformerFactory;
|
||||
using concretelang::values::TransportValue;
|
||||
using concretelang::values::Value;
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class ClientCircuit {
|
||||
|
||||
public:
|
||||
static Result<ClientCircuit>
|
||||
create(const Message<concreteprotocol::CircuitInfo> &info,
|
||||
const ClientKeyset &keyset, std::shared_ptr<CSPRNG> csprng,
|
||||
bool useSimulation = false);
|
||||
|
||||
Result<TransportValue> prepareInput(Value arg, size_t pos);
|
||||
|
||||
Result<Value> processOutput(TransportValue result, size_t pos);
|
||||
|
||||
std::string getName();
|
||||
|
||||
const Message<concreteprotocol::CircuitInfo> &getCircuitInfo();
|
||||
|
||||
private:
|
||||
ClientCircuit() = delete;
|
||||
ClientCircuit(const Message<concreteprotocol::CircuitInfo> &circuitInfo,
|
||||
std::vector<InputTransformer> inputTransformers,
|
||||
std::vector<OutputTransformer> outputTransformers)
|
||||
: circuitInfo(circuitInfo), inputTransformers(inputTransformers),
|
||||
outputTransformers(outputTransformers){};
|
||||
|
||||
private:
|
||||
Message<concreteprotocol::CircuitInfo> circuitInfo;
|
||||
std::vector<InputTransformer> inputTransformers;
|
||||
std::vector<OutputTransformer> outputTransformers;
|
||||
};
|
||||
|
||||
/// Contains all the context to generate inputs for a server call by the
|
||||
/// server lib.
|
||||
class ClientProgram {
|
||||
public:
|
||||
/// Generates a fresh client program with fresh keyset on the first use.
|
||||
static Result<ClientProgram>
|
||||
create(const Message<concreteprotocol::ProgramInfo> &info,
|
||||
const ClientKeyset &keyset, std::shared_ptr<CSPRNG> csprng,
|
||||
bool useSimulation = false);
|
||||
|
||||
/// Returns a reference to the named client circuit if it exists.
|
||||
Result<ClientCircuit> getClientCircuit(std::string circuitName);
|
||||
|
||||
private:
|
||||
ClientProgram() = default;
|
||||
|
||||
private:
|
||||
std::vector<ClientCircuit> circuits;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,350 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
#include <llvm/Support/JSON.h>
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
inline size_t bitWidthAsWord(size_t exactBitWidth) {
|
||||
if (exactBitWidth <= 8)
|
||||
return 8;
|
||||
if (exactBitWidth <= 16)
|
||||
return 16;
|
||||
if (exactBitWidth <= 32)
|
||||
return 32;
|
||||
if (exactBitWidth <= 64)
|
||||
return 64;
|
||||
assert(false && "Bit witdh > 64 not supported");
|
||||
}
|
||||
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
const uint64_t SMALL_KEY = 1;
|
||||
const uint64_t BIG_KEY = 0;
|
||||
|
||||
const std::string CLIENT_PARAMETERS_EXT = ".concrete.params.json";
|
||||
|
||||
typedef uint64_t DecompositionLevelCount;
|
||||
typedef uint64_t DecompositionBaseLog;
|
||||
typedef uint64_t PolynomialSize;
|
||||
typedef uint64_t Precision;
|
||||
typedef double Variance;
|
||||
typedef std::vector<int64_t> CRTDecomposition;
|
||||
|
||||
typedef uint64_t LweDimension;
|
||||
typedef uint64_t GlweDimension;
|
||||
|
||||
typedef uint64_t LweSecretKeyID;
|
||||
struct LweSecretKeyParam {
|
||||
LweDimension dimension;
|
||||
|
||||
void hash(size_t &seed);
|
||||
inline uint64_t lweDimension() { return dimension; }
|
||||
inline uint64_t lweSize() { return dimension + 1; }
|
||||
inline uint64_t byteSize() { return lweSize() * 8; }
|
||||
};
|
||||
static bool operator==(const LweSecretKeyParam &lhs,
|
||||
const LweSecretKeyParam &rhs) {
|
||||
return lhs.dimension == rhs.dimension;
|
||||
}
|
||||
|
||||
typedef uint64_t BootstrapKeyID;
|
||||
struct BootstrapKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
Variance variance;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
|
||||
void hash(size_t &seed);
|
||||
|
||||
uint64_t byteSize(uint64_t inputLweSize, uint64_t outputLweSize) {
|
||||
return inputLweSize * level * (glweDimension + 1) * (glweDimension + 1) *
|
||||
outputLweSize * 8;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const BootstrapKeyParam &lhs,
|
||||
const BootstrapKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension && lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef uint64_t KeyswitchKeyID;
|
||||
struct KeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
|
||||
size_t byteSize(size_t inputLweSize, size_t outputLweSize) {
|
||||
return level * inputLweSize * outputLweSize * 8;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const KeyswitchKeyParam &lhs,
|
||||
const KeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.variance == rhs.variance;
|
||||
}
|
||||
|
||||
typedef uint64_t PackingKeyswitchKeyID;
|
||||
struct PackingKeyswitchKeyParam {
|
||||
LweSecretKeyID inputSecretKeyID;
|
||||
LweSecretKeyID outputSecretKeyID;
|
||||
DecompositionLevelCount level;
|
||||
DecompositionBaseLog baseLog;
|
||||
GlweDimension glweDimension;
|
||||
PolynomialSize polynomialSize;
|
||||
LweDimension inputLweDimension;
|
||||
Variance variance;
|
||||
|
||||
void hash(size_t &seed);
|
||||
};
|
||||
static inline bool operator==(const PackingKeyswitchKeyParam &lhs,
|
||||
const PackingKeyswitchKeyParam &rhs) {
|
||||
return lhs.inputSecretKeyID == rhs.inputSecretKeyID &&
|
||||
lhs.outputSecretKeyID == rhs.outputSecretKeyID &&
|
||||
lhs.level == rhs.level && lhs.baseLog == rhs.baseLog &&
|
||||
lhs.glweDimension == rhs.glweDimension &&
|
||||
lhs.polynomialSize == rhs.polynomialSize &&
|
||||
lhs.variance == lhs.variance &&
|
||||
lhs.inputLweDimension == rhs.inputLweDimension;
|
||||
}
|
||||
|
||||
struct Encoding {
|
||||
Precision precision;
|
||||
CRTDecomposition crt;
|
||||
bool isSigned;
|
||||
};
|
||||
static inline bool operator==(const Encoding &lhs, const Encoding &rhs) {
|
||||
return lhs.precision == rhs.precision && lhs.isSigned == rhs.isSigned;
|
||||
}
|
||||
|
||||
struct EncryptionGate {
|
||||
LweSecretKeyID secretKeyID;
|
||||
Variance variance;
|
||||
Encoding encoding;
|
||||
};
|
||||
static inline bool operator==(const EncryptionGate &lhs,
|
||||
const EncryptionGate &rhs) {
|
||||
return lhs.secretKeyID == rhs.secretKeyID && lhs.variance == rhs.variance &&
|
||||
lhs.encoding == rhs.encoding;
|
||||
}
|
||||
|
||||
struct CircuitGateShape {
|
||||
/// Width of the scalar value
|
||||
uint64_t width;
|
||||
/// Dimensions of the tensor, empty if scalar
|
||||
std::vector<int64_t> dimensions;
|
||||
/// Size of the buffer containing the tensor
|
||||
uint64_t size;
|
||||
// Indicated whether elements are signed
|
||||
bool sign;
|
||||
};
|
||||
static inline bool operator==(const CircuitGateShape &lhs,
|
||||
const CircuitGateShape &rhs) {
|
||||
return lhs.width == rhs.width && lhs.dimensions == rhs.dimensions &&
|
||||
lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct ChunkInfo {
|
||||
/// total number of bits used for the chunk including the carry.
|
||||
/// size should be at least width + 1
|
||||
unsigned int size;
|
||||
/// number of bits used for the chunk excluding the carry
|
||||
unsigned int width;
|
||||
};
|
||||
static inline bool operator==(const ChunkInfo &lhs, const ChunkInfo &rhs) {
|
||||
return lhs.width == rhs.width && lhs.size == rhs.size;
|
||||
}
|
||||
|
||||
struct CircuitGate {
|
||||
std::optional<EncryptionGate> encryption;
|
||||
CircuitGateShape shape;
|
||||
std::optional<ChunkInfo> chunkInfo;
|
||||
|
||||
bool isEncrypted() { return encryption.has_value(); }
|
||||
|
||||
/// byteSize returns the size in bytes for this gate.
|
||||
size_t byteSize(std::vector<LweSecretKeyParam> secretKeys) {
|
||||
auto width = shape.width;
|
||||
auto numElts = shape.size == 0 ? 1 : shape.size;
|
||||
if (isEncrypted()) {
|
||||
assert(encryption->secretKeyID < secretKeys.size());
|
||||
auto skParam = secretKeys[encryption->secretKeyID];
|
||||
return 8 * skParam.lweSize() * numElts;
|
||||
}
|
||||
width = bitWidthAsWord(width) / 8;
|
||||
return width * numElts;
|
||||
}
|
||||
};
|
||||
static inline bool operator==(const CircuitGate &lhs, const CircuitGate &rhs) {
|
||||
return lhs.encryption == rhs.encryption && lhs.shape == rhs.shape &&
|
||||
lhs.chunkInfo == rhs.chunkInfo;
|
||||
}
|
||||
|
||||
struct ClientParameters {
|
||||
std::vector<LweSecretKeyParam> secretKeys;
|
||||
std::vector<BootstrapKeyParam> bootstrapKeys;
|
||||
std::vector<KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
|
||||
std::vector<CircuitGate> inputs;
|
||||
std::vector<CircuitGate> outputs;
|
||||
std::string functionName;
|
||||
|
||||
size_t hash();
|
||||
|
||||
static outcome::checked<std::vector<ClientParameters>, StringError>
|
||||
load(std::string path);
|
||||
|
||||
static std::string getClientParametersPath(std::string path);
|
||||
|
||||
outcome::checked<CircuitGate, StringError> input(size_t pos) {
|
||||
if (pos >= inputs.size()) {
|
||||
return StringError("input gate ") << pos << " didn't exists";
|
||||
}
|
||||
return inputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<CircuitGate, StringError> ouput(size_t pos) {
|
||||
if (pos >= outputs.size()) {
|
||||
return StringError("output gate ") << pos << " didn't exists";
|
||||
}
|
||||
return outputs[pos];
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKeyParam, StringError>
|
||||
lweSecretKeyParam(CircuitGate gate) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("gate is not encrypted");
|
||||
}
|
||||
assert(gate.encryption->secretKeyID < secretKeys.size());
|
||||
auto secretKey = secretKeys[gate.encryption->secretKeyID];
|
||||
return secretKey;
|
||||
}
|
||||
|
||||
/// bufferSize returns the size of the whole buffer of a gate.
|
||||
int64_t bufferSize(CircuitGate gate) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
// Value is not encrypted just returns the tensor size
|
||||
return gate.shape.size;
|
||||
}
|
||||
auto shapeSize = gate.shape.size == 0 ? 1 : gate.shape.size;
|
||||
// Size of the ciphertext
|
||||
return shapeSize * lweBufferSize(gate);
|
||||
}
|
||||
|
||||
/// lweBufferSize returns the size of one ciphertext of a gate.
|
||||
int64_t lweBufferSize(CircuitGate gate) {
|
||||
assert(gate.encryption.has_value());
|
||||
auto nbBlocks = gate.encryption->encoding.crt.size();
|
||||
nbBlocks = nbBlocks == 0 ? 1 : nbBlocks;
|
||||
|
||||
auto param = lweSecretKeyParam(gate);
|
||||
assert(param.has_value());
|
||||
return param.value().lweSize() * nbBlocks;
|
||||
}
|
||||
|
||||
/// bufferShape returns the shape of the tensor for the given gate. It returns
|
||||
/// the shape used at low-level, i.e. contains the dimensions for ciphertexts
|
||||
/// (if not in simulation).
|
||||
std::vector<int64_t> bufferShape(CircuitGate gate, bool simulation = false) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
// Value is not encrypted just returns the tensor shape
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
auto lweSecreteKeyParam = lweSecretKeyParam(gate);
|
||||
assert(lweSecreteKeyParam.has_value());
|
||||
|
||||
// Copy the shape
|
||||
std::vector<int64_t> shape(gate.shape.dimensions);
|
||||
|
||||
auto crt = gate.encryption->encoding.crt;
|
||||
|
||||
// CRT case: Add one dimension equals to the number of blocks
|
||||
if (!crt.empty()) {
|
||||
shape.push_back(crt.size());
|
||||
}
|
||||
if (!simulation) {
|
||||
// Add one dimension for the size of ciphertext(s)
|
||||
shape.push_back(lweSecreteKeyParam.value().lweSize());
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
};
|
||||
|
||||
static inline bool operator==(const ClientParameters &lhs,
|
||||
const ClientParameters &rhs) {
|
||||
return lhs.secretKeys == rhs.secretKeys &&
|
||||
lhs.bootstrapKeys == rhs.bootstrapKeys &&
|
||||
lhs.keyswitchKeys == rhs.keyswitchKeys && lhs.inputs == lhs.inputs &&
|
||||
lhs.outputs == lhs.outputs;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, LweSecretKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, BootstrapKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, KeyswitchKeyParam &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &);
|
||||
bool fromJSON(const llvm::json::Value, PackingKeyswitchKeyParam &,
|
||||
llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const EncryptionGate &);
|
||||
bool fromJSON(const llvm::json::Value, EncryptionGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGateShape &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &);
|
||||
bool fromJSON(const llvm::json::Value, CircuitGate &, llvm::json::Path);
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &);
|
||||
bool fromJSON(const llvm::json::Value, ClientParameters &, llvm::json::Path);
|
||||
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_string_ostream &OS,
|
||||
ClientParameters cp) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(cp));
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,189 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
#define CONCRETELANG_CLIENTLIB_ENCRYPTED_ARGS_H
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ClientLib/ValueExporter.h"
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class PublicArguments;
|
||||
|
||||
/// Temporary object used to hold and encrypt parameters before calling a
|
||||
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
|
||||
/// Otherwise convert it to a PublicArguments and use
|
||||
/// serializeCall(PublicArguments, KeySet).
|
||||
class EncryptedArguments {
|
||||
|
||||
public:
|
||||
EncryptedArguments(bool simulation = false) : simulation(simulation) {}
|
||||
|
||||
/// Encrypts args thanks the given KeySet and pack the encrypted arguments
|
||||
/// to an EncryptedArguments
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(bool simulation, KeySet &keySet, Args... args) {
|
||||
auto encryptedArgs = std::make_unique<EncryptedArguments>(simulation);
|
||||
OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
template <typename ArgT>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(bool simulation, KeySet &keySet, const llvm::ArrayRef<ArgT> args) {
|
||||
auto encryptedArgs = EncryptedArguments::empty(simulation);
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet));
|
||||
}
|
||||
OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
static std::unique_ptr<EncryptedArguments> empty(bool simulation = false) {
|
||||
return std::make_unique<EncryptedArguments>(simulation);
|
||||
}
|
||||
|
||||
bool isSimulated() { return simulation; }
|
||||
|
||||
/// Export encrypted arguments as public arguments, reset the encrypted
|
||||
/// arguments, i.e. move all buffers to the PublicArguments and reset the
|
||||
/// positional counter.
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
exportPublicArguments(ClientParameters clientParameters);
|
||||
|
||||
/// Check that all arguments as been pushed.
|
||||
// TODO: Remove public method here
|
||||
outcome::checked<void, StringError> checkAllArgs(KeySet &keySet);
|
||||
outcome::checked<void, StringError> checkAllArgs(ClientParameters ¶ms);
|
||||
|
||||
public:
|
||||
/// Add a uint64_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg, KeySet &keySet) {
|
||||
auto exporter = getExporter(keySet);
|
||||
OUTCOME_TRY(auto value, exporter->exportValue(arg, values.size()));
|
||||
values.push_back(std::move(value));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
/// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument with data and size of the dimension.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(const T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg(std::vector<uint8_t>(data, data + dim1), keySet);
|
||||
}
|
||||
|
||||
/// Add a 1D tensor argument.
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 2D tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg, KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(), llvm::ArrayRef<int64_t>{size0, size1},
|
||||
keySet);
|
||||
}
|
||||
|
||||
/// Add a 3D tensor argument.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
KeySet &keySet) {
|
||||
return pushArg((uint8_t *)arg.data(),
|
||||
llvm::ArrayRef<int64_t>{size0, size1, size2}, keySet);
|
||||
}
|
||||
|
||||
// Generalize by computing shape by template recursion
|
||||
|
||||
/// Set a argument at the given pos as a 1D tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError> pushArg(T *data, int64_t dim1,
|
||||
KeySet &keySet) {
|
||||
return pushArg<T>(data, llvm::ArrayRef<int64_t>(&dim1, 1), keySet);
|
||||
}
|
||||
|
||||
/// Set a argument at the given pos as a tensor of T.
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
return pushArg(static_cast<const T *>(data), shape, keySet);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(const T *data, llvm::ArrayRef<int64_t> shape, KeySet &keySet) {
|
||||
auto exporter = getExporter(keySet);
|
||||
OUTCOME_TRY(auto value, exporter->exportValue(data, shape, values.size()));
|
||||
values.push_back(std::move(value));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
/// Recursive case for scalars: extract first scalar argument from
|
||||
/// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet, Arg0 arg0,
|
||||
OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
/// Recursive case for tensors: extract pointer and size from
|
||||
/// parameter pack and forward rest
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError>
|
||||
pushArgs(KeySet &keySet, Arg0 *arg0, size_t size, OtherArgs... others) {
|
||||
OUTCOME_TRYV(pushArg(arg0, size, keySet));
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
/// Terminal case of pushArgs
|
||||
outcome::checked<void, StringError> pushArgs(KeySet &keySet) {
|
||||
return checkAllArgs(keySet);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<ValueExporterInterface> getExporter(KeySet &keySet) {
|
||||
if (isSimulated()) {
|
||||
return std::make_unique<SimulatedValueExporter>(
|
||||
keySet.clientParameters());
|
||||
} else {
|
||||
return std::make_unique<ValueExporter>(keySet, keySet.clientParameters());
|
||||
}
|
||||
}
|
||||
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> values;
|
||||
/// Whether it a simulates an encrypted argument or not
|
||||
bool simulation;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,194 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
#define CONCRETELANG_CLIENTLIB_EVALUATION_KEYS_H_
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
struct Csprng;
|
||||
struct CsprngVtable;
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class CSPRNG {
|
||||
public:
|
||||
struct Csprng *ptr;
|
||||
const struct CsprngVtable *vtable;
|
||||
|
||||
CSPRNG() = delete;
|
||||
CSPRNG(CSPRNG &) = delete;
|
||||
|
||||
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
};
|
||||
|
||||
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
|
||||
};
|
||||
|
||||
class ConcreteCSPRNG : public CSPRNG {
|
||||
public:
|
||||
ConcreteCSPRNG(__uint128_t seed);
|
||||
ConcreteCSPRNG() = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &&other);
|
||||
~ConcreteCSPRNG();
|
||||
};
|
||||
|
||||
/// @brief LweSecretKey implements tools for manipulating lwe secret key on
|
||||
/// client.
|
||||
class LweSecretKey {
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
LweSecretKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweSecretKey() = delete;
|
||||
LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng);
|
||||
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
LweSecretKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Encrypt the plaintext to the lwe ciphertext buffer.
|
||||
void encrypt(uint64_t *ciphertext, uint64_t plaintext, double variance,
|
||||
CSPRNG &csprng) const;
|
||||
|
||||
/// @brief Decrypt the ciphertext to the plaintext
|
||||
void decrypt(const uint64_t *ciphertext, uint64_t &plaintext) const;
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
LweSecretKeyParam parameters() const { return this->_parameters; }
|
||||
|
||||
/// @brief Returns the lwe dimension of the secret key.
|
||||
size_t dimension() const { return parameters().dimension; }
|
||||
};
|
||||
|
||||
/// @brief LweKeyswitchKey implements tools for manipulating keyswitch key on
|
||||
/// client.
|
||||
class LweKeyswitchKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
KeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweKeyswitchKey() = delete;
|
||||
LweKeyswitchKey(KeyswitchKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
KeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
KeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
/// @brief LweBootstrapKey implements tools for manipulating bootstrap key on
|
||||
/// client.
|
||||
class LweBootstrapKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
BootstrapKeyParam _parameters;
|
||||
|
||||
public:
|
||||
LweBootstrapKey() = delete;
|
||||
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
BootstrapKeyParam ¶meters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
LweBootstrapKey(BootstrapKeyParam ¶meters, LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
|
||||
///// @brief Returns the buffer that hold the bootstrap key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the bootsrap key.
|
||||
BootstrapKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
/// @brief PackingKeyswitchKey implements tools for manipulating privat packing
|
||||
/// keyswitch key on client.
|
||||
class PackingKeyswitchKey {
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> _buffer;
|
||||
PackingKeyswitchKeyParam _parameters;
|
||||
|
||||
public:
|
||||
PackingKeyswitchKey() = delete;
|
||||
PackingKeyswitchKey(PackingKeyswitchKeyParam ¶meters,
|
||||
LweSecretKey &inputKey, LweSecretKey &outputKey,
|
||||
CSPRNG &csprng);
|
||||
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
PackingKeyswitchKeyParam parameters)
|
||||
: _buffer(buffer), _parameters(parameters){};
|
||||
|
||||
/// @brief Returns the buffer that hold the keyswitch key.
|
||||
const uint64_t *buffer() const { return _buffer->data(); }
|
||||
size_t size() const { return _buffer->size(); }
|
||||
|
||||
/// @brief Returns the parameters of the keyswicth key.
|
||||
PackingKeyswitchKeyParam parameters() const { return this->_parameters; }
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
/// Evalution keys required for execution.
|
||||
class EvaluationKeys {
|
||||
private:
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
public:
|
||||
EvaluationKeys() = delete;
|
||||
|
||||
EvaluationKeys(const std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
const std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
const std::vector<PackingKeyswitchKey> packingKeyswitchKeys)
|
||||
: keyswitchKeys(keyswitchKeys), bootstrapKeys(bootstrapKeys),
|
||||
packingKeyswitchKeys(packingKeyswitchKeys) {}
|
||||
|
||||
const LweKeyswitchKey &getKeyswitchKey(size_t id) const {
|
||||
return this->keyswitchKeys[id];
|
||||
}
|
||||
const std::vector<LweKeyswitchKey> getKeyswitchKeys() const {
|
||||
return this->keyswitchKeys;
|
||||
}
|
||||
|
||||
const LweBootstrapKey &getBootstrapKey(size_t id) const {
|
||||
return bootstrapKeys[id];
|
||||
}
|
||||
const std::vector<LweBootstrapKey> getBootstrapKeys() const {
|
||||
return this->bootstrapKeys;
|
||||
}
|
||||
|
||||
const PackingKeyswitchKey &getPackingKeyswitchKey(size_t id) const {
|
||||
return this->packingKeyswitchKeys[id];
|
||||
};
|
||||
|
||||
const std::vector<PackingKeyswitchKey> getPackingKeyswitchKeys() const {
|
||||
return this->packingKeyswitchKeys;
|
||||
}
|
||||
};
|
||||
|
||||
// =============================================
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,128 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSET_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class KeySet {
|
||||
public:
|
||||
KeySet(ClientParameters clientParameters, CSPRNG &&csprng)
|
||||
: csprng(std::move(csprng)), _clientParameters(clientParameters){};
|
||||
KeySet(KeySet &other) = delete;
|
||||
|
||||
/// Generate a KeySet from a ClientParameters specification.
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters clientParameters, CSPRNG &&csprng);
|
||||
|
||||
/// Create a KeySet from a set of given keys
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError> fromKeys(
|
||||
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
|
||||
std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng);
|
||||
|
||||
/// Returns the ClientParameters associated with the KeySet.
|
||||
ClientParameters clientParameters() const { return _clientParameters; }
|
||||
|
||||
// isInputEncrypted return true if the input at the given pos is encrypted.
|
||||
bool isInputEncrypted(size_t pos);
|
||||
|
||||
/// allocate a lwe ciphertext buffer for the argument at argPos, set the size
|
||||
/// of the allocated buffer.
|
||||
outcome::checked<void, StringError>
|
||||
allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size);
|
||||
|
||||
/// encrypt the input to the ciphertext for the argument at argPos.
|
||||
outcome::checked<void, StringError>
|
||||
encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
|
||||
|
||||
/// isOuputEncrypted return true if the output at the given pos is encrypted.
|
||||
bool isOutputEncrypted(size_t pos);
|
||||
|
||||
/// decrypt the ciphertext to the output for the argument at argPos.
|
||||
outcome::checked<void, StringError>
|
||||
decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output);
|
||||
|
||||
size_t numInputs() { return inputs.size(); }
|
||||
size_t numOutputs() { return outputs.size(); }
|
||||
|
||||
CircuitGate inputGate(size_t pos) { return std::get<0>(inputs[pos]); }
|
||||
CircuitGate outputGate(size_t pos) { return std::get<0>(outputs[pos]); }
|
||||
|
||||
/// @brief evaluationKeys returns the evaluation keys associate to this client
|
||||
/// keyset. Those evaluations keys can be safely shared publicly
|
||||
EvaluationKeys evaluationKeys();
|
||||
|
||||
const std::vector<LweSecretKey> &getSecretKeys() const;
|
||||
|
||||
const std::vector<LweBootstrapKey> &getBootstrapKeys() const;
|
||||
|
||||
const std::vector<LweKeyswitchKey> &getKeyswitchKeys() const;
|
||||
|
||||
const std::vector<PackingKeyswitchKey> &getPackingKeyswitchKeys() const;
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError>
|
||||
generateSecretKey(LweSecretKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateBootstrapKey(BootstrapKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generateKeyswitchKey(KeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
generatePackingKeyswitchKey(PackingKeyswitchKeyParam param);
|
||||
|
||||
outcome::checked<void, StringError> generateKeysFromParams();
|
||||
|
||||
outcome::checked<void, StringError> setupEncryptionMaterial();
|
||||
|
||||
friend class KeySetCache;
|
||||
|
||||
private:
|
||||
CSPRNG csprng;
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Keys mappings
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
outcome::checked<LweSecretKey, StringError> findLweSecretKey(LweSecretKeyID);
|
||||
|
||||
///////////////////////////////////////////////
|
||||
// Convenient positional mapping between positional gate en secret key
|
||||
typedef std::vector<std::pair<CircuitGate, std::optional<LweSecretKey>>>
|
||||
SecretKeyGateMapping;
|
||||
outcome::checked<SecretKeyGateMapping, StringError>
|
||||
mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates);
|
||||
|
||||
SecretKeyGateMapping inputs;
|
||||
SecretKeyGateMapping outputs;
|
||||
|
||||
clientlib::ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,43 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
#define CONCRETELANG_CLIENTLIB_KEYSETCACHE_H_
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
class KeySet;
|
||||
|
||||
class KeySetCache {
|
||||
std::string backingDirectoryPath;
|
||||
|
||||
public:
|
||||
KeySetCache(std::string backingDirectoryPath)
|
||||
: backingDirectoryPath(backingDirectoryPath) {}
|
||||
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
std::string folderPath);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb);
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,146 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_PUBLIC_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/ClientLib/ValueDecrypter.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
class ServerLambda;
|
||||
}
|
||||
} // namespace concretelang
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
class JITLambda;
|
||||
}
|
||||
} // namespace mlir
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::clientlib::ValueDecrypter;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class EncryptedArguments;
|
||||
|
||||
/// PublicArguments will be sended to the server. It includes encrypted
|
||||
/// arguments and public keys.
|
||||
class PublicArguments {
|
||||
public:
|
||||
PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers);
|
||||
~PublicArguments();
|
||||
|
||||
static outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
unserialize(const ClientParameters &expectedParams, std::istream &istream);
|
||||
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
std::vector<SharedScalarOrTensorData> &getArguments() { return arguments; }
|
||||
ClientParameters &getClientParameters() { return clientParameters; }
|
||||
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
friend class ::mlir::concretelang::JITLambda;
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<SharedScalarOrTensorData> arguments;
|
||||
};
|
||||
|
||||
/// PublicResult is a result of a ServerLambda call which contains encrypted
|
||||
/// results.
|
||||
struct PublicResult {
|
||||
|
||||
PublicResult(const ClientParameters &clientParameters,
|
||||
std::vector<SharedScalarOrTensorData> &&buffers = {})
|
||||
: clientParameters(clientParameters), buffers(std::move(buffers)){};
|
||||
|
||||
PublicResult(PublicResult &) = delete;
|
||||
|
||||
/// @brief Return a value from the PublicResult
|
||||
/// @param argPos The position of the value in the PublicResult
|
||||
/// @return Either the value or an error if there are no value at this
|
||||
/// position
|
||||
outcome::checked<SharedScalarOrTensorData, StringError>
|
||||
getValue(size_t argPos) {
|
||||
if (argPos >= buffers.size()) {
|
||||
return StringError("result #") << argPos << " does not exists";
|
||||
}
|
||||
return buffers[argPos];
|
||||
}
|
||||
|
||||
/// Create a public result from buffers.
|
||||
static std::unique_ptr<PublicResult>
|
||||
fromBuffers(const ClientParameters &clientParameters,
|
||||
std::vector<SharedScalarOrTensorData> &&buffers) {
|
||||
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
|
||||
}
|
||||
|
||||
/// Unserialize from an input stream inplace.
|
||||
outcome::checked<void, StringError> unserialize(std::istream &istream);
|
||||
/// Unserialize from an input stream returning a new PublicResult.
|
||||
static outcome::checked<std::unique_ptr<PublicResult>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream) {
|
||||
auto publicResult = std::make_unique<PublicResult>(expectedParams);
|
||||
OUTCOME_TRYV(publicResult->unserialize(istream));
|
||||
return std::move(publicResult);
|
||||
}
|
||||
/// Serialize into an output stream.
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
/// Get the result at `pos` as a scalar. Decryption happens if the
|
||||
/// result is encrypted.
|
||||
template <typename T>
|
||||
outcome::checked<T, StringError> asClearTextScalar(KeySet &keySet,
|
||||
size_t pos) {
|
||||
ValueDecrypter decrypter(keySet, clientParameters);
|
||||
auto &data = buffers[pos].get();
|
||||
return decrypter.template decrypt<T>(data, pos);
|
||||
}
|
||||
|
||||
/// Get the result at `pos` as a vector. Decryption happens if the
|
||||
/// result is encrypted.
|
||||
template <typename T>
|
||||
outcome::checked<std::vector<T>, StringError>
|
||||
asClearTextVector(KeySet &keySet, size_t pos) {
|
||||
ValueDecrypter decrypter(keySet, clientParameters);
|
||||
return decrypter.template decryptTensor<T>(buffers[pos].get(), pos);
|
||||
}
|
||||
|
||||
/// Return the shape of the clear tensor of a result.
|
||||
outcome::checked<std::vector<int64_t>, StringError>
|
||||
asClearTextShape(size_t pos) {
|
||||
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
|
||||
// private: TODO tmp
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
std::vector<SharedScalarOrTensorData> buffers;
|
||||
};
|
||||
|
||||
/// Helper function to convert from MemRefDescriptor to
|
||||
/// TensorData
|
||||
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
|
||||
bool is_signed, void *allocated, void *aligned,
|
||||
size_t offset, size_t *sizes, size_t *strides);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,129 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
// integers are not serialized as binary values even on a binary stream
|
||||
// so we cannot rely on << operator directly
|
||||
template <typename Word>
|
||||
std::ostream &writeWord(std::ostream &ostream, Word word) {
|
||||
ostream.write(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::ostream &writeSize(std::ostream &ostream, Size size) {
|
||||
return writeWord(ostream, size);
|
||||
}
|
||||
|
||||
// for sake of symetry
|
||||
template <typename Word>
|
||||
std::istream &readWord(std::istream &istream, Word &word) {
|
||||
istream.read(reinterpret_cast<char *>(&(word)), sizeof(word));
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
template <typename Word>
|
||||
std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
|
||||
assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
|
||||
istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
template <typename Size>
|
||||
std::istream &readSize(std::istream &istream, Size &size) {
|
||||
return readWord(istream, size);
|
||||
}
|
||||
|
||||
template <typename Stream> bool incorrectMode(Stream &stream) {
|
||||
auto binary = stream.flags() && std::ios::binary;
|
||||
if (!binary) {
|
||||
stream.setstate(std::ios::failbit);
|
||||
}
|
||||
return !binary;
|
||||
}
|
||||
|
||||
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarData, StringError>
|
||||
unserializeScalarData(std::istream &istream);
|
||||
|
||||
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
|
||||
std::ostream &ostream);
|
||||
|
||||
template <typename T>
|
||||
std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
|
||||
const llvm::ArrayRef<T> &values,
|
||||
std::ostream &ostream) {
|
||||
|
||||
writeWord<uint64_t>(ostream, dimensions.size());
|
||||
|
||||
for (size_t dim : dimensions)
|
||||
writeWord<int64_t>(ostream, dim);
|
||||
|
||||
writeWord<uint64_t>(ostream, sizeof(T) * 8);
|
||||
writeWord<uint8_t>(ostream, std::is_signed<T>());
|
||||
|
||||
for (T val : values)
|
||||
writeWord(ostream, val);
|
||||
|
||||
return ostream;
|
||||
}
|
||||
|
||||
outcome::checked<TensorData, StringError>
|
||||
unserializeTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &serializeVectorOfScalarOrTensorData(
|
||||
const std::vector<SharedScalarOrTensorData> &sotd, std::ostream &ostream);
|
||||
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
|
||||
LweSecretKey readLweSecretKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweKeyswitchKey &wrappedKsk);
|
||||
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const LweBootstrapKey &wrappedBsk);
|
||||
LweBootstrapKey readLweBootstrapKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &wrappedKsk);
|
||||
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const KeySet &keySet);
|
||||
std::unique_ptr<KeySet> readKeySet(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys);
|
||||
EvaluationKeys readEvaluationKeys(std::istream &istream);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,898 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
#define CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <stddef.h>
|
||||
#include <vector>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
using decrypted_scalar_t = std::uint64_t;
|
||||
using decrypted_tensor_1_t = std::vector<decrypted_scalar_t>;
|
||||
using decrypted_tensor_2_t = std::vector<decrypted_tensor_1_t>;
|
||||
using decrypted_tensor_3_t = std::vector<decrypted_tensor_2_t>;
|
||||
|
||||
template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
|
||||
using encrypted_scalar_t = uint64_t *;
|
||||
using encrypted_scalars_t = uint64_t *;
|
||||
|
||||
// Element types for `TensorData`
|
||||
enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };
|
||||
|
||||
// Returns the width in bits of an integer whose width is a power of
|
||||
// two that can hold values with at most `width` bits
|
||||
static inline constexpr size_t getStorageWidth(size_t width) {
|
||||
if (width > 64)
|
||||
assert(false && "Unsupported scalar width");
|
||||
|
||||
if (width > 32) {
|
||||
return 64;
|
||||
} else if (width > 16) {
|
||||
return 32;
|
||||
} else if (width > 8) {
|
||||
return 16;
|
||||
} else {
|
||||
return 8;
|
||||
}
|
||||
}
|
||||
|
||||
// Translates `sign` and `width` into an `ElementType`.
|
||||
static inline ElementType getElementTypeFromWidthAndSign(size_t width,
|
||||
bool sign) {
|
||||
switch (getStorageWidth(width)) {
|
||||
case 64:
|
||||
return (sign) ? ElementType::i64 : ElementType::u64;
|
||||
case 32:
|
||||
return (sign) ? ElementType::i32 : ElementType::u32;
|
||||
case 16:
|
||||
return (sign) ? ElementType::i16 : ElementType::u16;
|
||||
case 8:
|
||||
return (sign) ? ElementType::i8 : ElementType::u8;
|
||||
default:
|
||||
assert(false && "Unsupported scalar width");
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns the number of bits for an element type
|
||||
static constexpr size_t getElementTypeWidth(ElementType t) {
|
||||
switch (t) {
|
||||
case ElementType::u64:
|
||||
case ElementType::i64:
|
||||
return 64;
|
||||
case ElementType::u32:
|
||||
case ElementType::i32:
|
||||
return 32;
|
||||
case ElementType::u16:
|
||||
case ElementType::i16:
|
||||
return 16;
|
||||
case ElementType::u8:
|
||||
case ElementType::i8:
|
||||
return 8;
|
||||
}
|
||||
|
||||
// Cannot happen
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Returns `true` if the element type `t` designates a signed type,
|
||||
// otherwise `false`.
|
||||
static constexpr size_t getElementTypeSignedness(ElementType t) {
|
||||
switch (t) {
|
||||
case ElementType::u64:
|
||||
case ElementType::u32:
|
||||
case ElementType::u16:
|
||||
case ElementType::u8:
|
||||
return false;
|
||||
case ElementType::i64:
|
||||
case ElementType::i32:
|
||||
case ElementType::i16:
|
||||
case ElementType::i8:
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cannot happen
|
||||
return false;
|
||||
}
|
||||
|
||||
// Returns `true` iff the element type `t` designates the smallest
|
||||
// unsigned / signed (depending on `sign`) integer type that can hold
|
||||
// values of up to `width` bits, otherwise false.
|
||||
static inline bool checkElementTypeForWidthAndSign(ElementType t, size_t width,
|
||||
bool sign) {
|
||||
return getElementTypeFromWidthAndSign(getStorageWidth(width), sign) == t;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Constants for the element types used for tensors representing
|
||||
// encrypted data and data after decryption
|
||||
constexpr ElementType EncryptedScalarElementType = ElementType::u64;
|
||||
constexpr size_t EncryptedScalarElementWidth =
|
||||
getElementTypeWidth(EncryptedScalarElementType);
|
||||
|
||||
using EncryptedScalarElement = uint64_t;
|
||||
|
||||
namespace detail {
|
||||
namespace TensorData {
|
||||
|
||||
// Union used to store the pointer to the actual data of an instance
|
||||
// of `TensorData`. Values are stored contiguously in memory in a
|
||||
// `std::vector` whose element type corresponds to the element type of
|
||||
// the tensor.
|
||||
union value_vector_union {
|
||||
std::vector<uint64_t> *u64;
|
||||
std::vector<int64_t> *i64;
|
||||
std::vector<uint32_t> *u32;
|
||||
std::vector<int32_t> *i32;
|
||||
std::vector<uint16_t> *u16;
|
||||
std::vector<int16_t> *i16;
|
||||
std::vector<uint8_t> *u8;
|
||||
std::vector<int8_t> *i8;
|
||||
};
|
||||
|
||||
// Function templates that would go into the class `TensorData`, but
|
||||
// which need to declared in namespace scope, since specializations of
|
||||
// templates on the return type cannot be done for member functions as
|
||||
// per the C++ standard
|
||||
template <typename T> T begin(union value_vector_union &vec);
|
||||
template <typename T> T end(union value_vector_union &vec);
|
||||
template <typename T> T cbegin(union value_vector_union &vec);
|
||||
template <typename T> T cend(union value_vector_union &vec);
|
||||
template <typename T> T getElements(union value_vector_union &vec);
|
||||
template <typename T> T getConstElements(const union value_vector_union &vec);
|
||||
|
||||
template <typename T>
|
||||
T getElementValue(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
template <typename T>
|
||||
T &getElementReference(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
template <typename T>
|
||||
T *getElementPointer(union value_vector_union &vec, size_t idx,
|
||||
ElementType elementType);
|
||||
|
||||
// Specializations for the above templates
|
||||
#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->begin(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::iterator end(union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->end(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::const_iterator cbegin( \
|
||||
union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->cbegin(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY>::const_iterator cend( \
|
||||
union value_vector_union &vec) { \
|
||||
return vec.SUFFIX->cend(); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline std::vector<ELTY> &getElements(union value_vector_union &vec) { \
|
||||
return *vec.SUFFIX; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline const std::vector<ELTY> &getConstElements( \
|
||||
const union value_vector_union &vec) { \
|
||||
return *vec.SUFFIX; \
|
||||
}
|
||||
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
|
||||
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)
|
||||
|
||||
#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline ELTY getElementValue(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return (*vec.SUFFIX)[idx]; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return (*vec.SUFFIX)[idx]; \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx, \
|
||||
ElementType elementType) { \
|
||||
assert(elementType == ElementType::SUFFIX); \
|
||||
return &(*vec.SUFFIX)[idx]; \
|
||||
}
|
||||
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
|
||||
TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
|
||||
|
||||
} // namespace TensorData
|
||||
} // namespace detail
|
||||
|
||||
// Representation of a tensor with an arbitrary number of dimensions
|
||||
class TensorData {
|
||||
protected:
|
||||
detail::TensorData::value_vector_union values;
|
||||
ElementType elementType;
|
||||
std::vector<size_t> dimensions;
|
||||
size_t elementWidth;
|
||||
|
||||
/* Multi-dimensional, uninitialized, but preallocated tensor */
|
||||
void initPreallocated(llvm::ArrayRef<size_t> dimensions,
|
||||
ElementType elementType, size_t elementWidth,
|
||||
bool sign) {
|
||||
assert(checkElementTypeForWidthAndSign(elementType, elementWidth, sign) &&
|
||||
"Incoherent parameters for element type, width and sign");
|
||||
|
||||
assert(dimensions.size() != 0);
|
||||
|
||||
size_t n = getNumElements(dimensions);
|
||||
|
||||
switch (elementType) {
|
||||
case ElementType::u64:
|
||||
this->values.u64 = new std::vector<uint64_t>(n);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->values.i64 = new std::vector<int64_t>(n);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->values.u32 = new std::vector<uint32_t>(n);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->values.i32 = new std::vector<int32_t>(n);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->values.u16 = new std::vector<uint16_t>(n);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->values.i16 = new std::vector<int16_t>(n);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->values.u8 = new std::vector<uint8_t>(n);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->values.i8 = new std::vector<int8_t>(n);
|
||||
break;
|
||||
}
|
||||
|
||||
this->dimensions.resize(dimensions.size());
|
||||
this->elementWidth = elementWidth;
|
||||
this->elementType = elementType;
|
||||
std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
|
||||
}
|
||||
|
||||
// Creates a vector<size_t> from an ArrayRef<T>
|
||||
template <typename T>
|
||||
static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
|
||||
return std::vector<size_t>(dims.begin(), dims.end());
|
||||
}
|
||||
|
||||
public:
|
||||
// Returns the total number of elements of a tensor with the
|
||||
// specified dimensions
|
||||
template <typename T> static size_t getNumElements(T dimensions) {
|
||||
size_t n = 1;
|
||||
for (auto dim : dimensions)
|
||||
n *= dim;
|
||||
|
||||
return n;
|
||||
}
|
||||
|
||||
// Move constructor. Leaves `that` uninitialized.
|
||||
TensorData(TensorData &&that)
|
||||
: elementType(that.elementType), dimensions(std::move(that.dimensions)),
|
||||
elementWidth(that.elementWidth) {
|
||||
switch (that.elementType) {
|
||||
case ElementType::u64:
|
||||
this->values.u64 = that.values.u64;
|
||||
that.values.u64 = nullptr;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->values.i64 = that.values.i64;
|
||||
that.values.i64 = nullptr;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->values.u32 = that.values.u32;
|
||||
that.values.u32 = nullptr;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->values.i32 = that.values.i32;
|
||||
that.values.i32 = nullptr;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->values.u16 = that.values.u16;
|
||||
that.values.u16 = nullptr;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->values.i16 = that.values.i16;
|
||||
that.values.i16 = nullptr;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->values.u8 = that.values.u8;
|
||||
that.values.u8 = nullptr;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->values.i8 = that.values.i8;
|
||||
that.values.i8 = nullptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Constructor to build a multi-dimensional tensor with the
|
||||
// corresponding element type. All elements are initialized with the
|
||||
// default value of `0`.
|
||||
TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType,
|
||||
size_t elementWidth) {
|
||||
initPreallocated(dimensions, elementType, elementWidth,
|
||||
getElementTypeSignedness(elementType));
|
||||
}
|
||||
|
||||
TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType,
|
||||
size_t elementWidth)
|
||||
: TensorData(toDimSpec(dimensions), elementType, elementWidth) {}
|
||||
|
||||
// Constructor to build a multi-dimensional tensor with the element
|
||||
// type corresponding to `elementWidth` and `sign`. All elements are
|
||||
// initialized with the default value of `0`.
|
||||
TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth, bool sign)
|
||||
: TensorData(dimensions,
|
||||
getElementTypeFromWidthAndSign(elementWidth, sign),
|
||||
elementWidth) {}
|
||||
|
||||
TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth, bool sign)
|
||||
: TensorData(toDimSpec(dimensions), elementWidth, sign) {}
|
||||
|
||||
#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX) \
|
||||
/* Multi-dimensional, initialized tensor, values copied from */ \
|
||||
/* `values` */ \
|
||||
TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions, \
|
||||
size_t elementWidth) \
|
||||
: dimensions(dimensions.begin(), dimensions.end()) { \
|
||||
assert(checkElementTypeForWidthAndSign(ElementType::SUFFIX, elementWidth, \
|
||||
std::is_signed<ELTY>()) && \
|
||||
"wrong element type for width"); \
|
||||
assert(dimensions.size() != 0); \
|
||||
size_t n = getNumElements(dimensions); \
|
||||
this->values.SUFFIX = new std::vector<ELTY>(n); \
|
||||
this->elementType = ElementType::SUFFIX; \
|
||||
this->bulkAssign(values); \
|
||||
} \
|
||||
\
|
||||
/* One-dimensional, initialized tensor. Values are copied from */ \
|
||||
/* `values` */ \
|
||||
TensorData(llvm::ArrayRef<ELTY> values, size_t width) \
|
||||
: TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}, \
|
||||
width) {}
|
||||
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
|
||||
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)
|
||||
|
||||
~TensorData() {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
delete values.u64;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
delete values.i64;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
delete values.u32;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
delete values.i32;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
delete values.u16;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
delete values.i16;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
delete values.u8;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
delete values.i8;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the total number of elements of the tensor
|
||||
size_t length() const { return getNumElements(this->dimensions); }
|
||||
|
||||
// Returns a vector with the size for each dimension of the tensor
|
||||
const std::vector<size_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
template <typename T> const std::vector<T> getDimensionsAs() const {
|
||||
return std::vector<T>(this->dimensions.begin(), this->dimensions.end());
|
||||
}
|
||||
|
||||
// Returns the number of dimensions
|
||||
size_t getRank() const { return this->dimensions.size(); }
|
||||
|
||||
// Multi-dimensional access to a tensor element
|
||||
template <typename T> T &operator[](llvm::ArrayRef<int64_t> index) {
|
||||
// Number of dimensions must match
|
||||
assert(index.size() == dimensions.size());
|
||||
|
||||
int64_t offset = 0;
|
||||
int64_t multiplier = 1;
|
||||
for (int64_t i = index.size() - 1; i > 0; i--) {
|
||||
offset += index[i] * multiplier;
|
||||
multiplier *= this->dimensions[i];
|
||||
}
|
||||
|
||||
return detail::TensorData::getElementReference<T>(values, offset,
|
||||
elementType);
|
||||
}
|
||||
|
||||
// Iterator pointing to the first element of a flat representation
|
||||
// of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator begin() {
|
||||
return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Iterator pointing past the last element of a flat representation
|
||||
// of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator end() {
|
||||
return detail::TensorData::end<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Const iterator pointing to the first element of a flat
|
||||
// representation of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator cbegin() {
|
||||
return detail::TensorData::cbegin<typename std::vector<T>::iterator>(
|
||||
values);
|
||||
}
|
||||
|
||||
// Const iterator pointing past the last element of a flat
|
||||
// representation of the tensor.
|
||||
template <typename T> typename std::vector<T>::iterator cend() {
|
||||
return detail::TensorData::cend<typename std::vector<T>::iterator>(values);
|
||||
}
|
||||
|
||||
// Flat representation of the const tensor
|
||||
template <typename T> const std::vector<T> &getElements() const {
|
||||
return detail::TensorData::getConstElements<const std::vector<T> &>(values);
|
||||
}
|
||||
|
||||
// Flat representation of the tensor
|
||||
template <typename T> const std::vector<T> &getElements() {
|
||||
return detail::TensorData::getElements<std::vector<T> &>(values);
|
||||
}
|
||||
|
||||
// Returns the `index`-th value of a flat representation of the tensor
|
||||
template <typename T> T getElementValue(size_t index) {
|
||||
return detail::TensorData::getElementValue<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a reference to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
template <typename T> T &getElementReference(size_t index) {
|
||||
return detail::TensorData::getElementReference<T>(values, index,
|
||||
elementType);
|
||||
}
|
||||
|
||||
// Returns a pointer to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
template <typename T> T *getElementPointer(size_t index) {
|
||||
return detail::TensorData::getElementPointer<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a pointer to the `index`-th value of a flat
|
||||
// representation of the tensor (const version)
|
||||
template <typename T> const T *getElementPointer(size_t index) const {
|
||||
return detail::TensorData::getElementPointer<T>(values, index, elementType);
|
||||
}
|
||||
|
||||
// Returns a void pointer to the `index`-th value of a flat
|
||||
// representation of the tensor
|
||||
void *getOpaqueElementPointer(size_t index) {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint64_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i64:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int64_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u32:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint32_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i32:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int32_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u16:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint16_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i16:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int16_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::u8:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<uint8_t>(values, index,
|
||||
elementType));
|
||||
case ElementType::i8:
|
||||
return reinterpret_cast<void *>(
|
||||
detail::TensorData::getElementPointer<int8_t>(values, index,
|
||||
elementType));
|
||||
}
|
||||
|
||||
assert(false && "Unknown element type");
|
||||
}
|
||||
|
||||
// Returns the element type of the tensor
|
||||
ElementType getElementType() const { return this->elementType; }
|
||||
|
||||
// Returns the actual width in bits of a data element (i.e., the
|
||||
// width specified upon construction and not the storage width of an
|
||||
// element)
|
||||
size_t getElementWidth() const { return this->elementWidth; }
|
||||
|
||||
// Returns the size of a tensor element in bytes (i.e., the storage width in
|
||||
// bytes)
|
||||
size_t getElementSize() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
case ElementType::i64:
|
||||
return 8;
|
||||
case ElementType::u32:
|
||||
case ElementType::i32:
|
||||
return 4;
|
||||
case ElementType::u16:
|
||||
case ElementType::i16:
|
||||
return 2;
|
||||
case ElementType::u8:
|
||||
case ElementType::i8:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns `true` if elements are signed, otherwise `false`
|
||||
bool getElementSignedness() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
case ElementType::u32:
|
||||
case ElementType::u16:
|
||||
case ElementType::u8:
|
||||
return false;
|
||||
case ElementType::i64:
|
||||
case ElementType::i32:
|
||||
case ElementType::i16:
|
||||
case ElementType::i8:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the total number of elements of the tensor
|
||||
size_t getNumElements() const { return getNumElements(this->dimensions); }
|
||||
|
||||
// Copy all elements from `values` to the tensor. Note that this
|
||||
// does not append values to the tensor, but overwrites existing
|
||||
// values.
|
||||
template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
|
||||
assert(values.size() <= this->getNumElements());
|
||||
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
std::copy(values.begin(), values.end(), this->values.u64->begin());
|
||||
break;
|
||||
case ElementType::i64:
|
||||
std::copy(values.begin(), values.end(), this->values.i64->begin());
|
||||
break;
|
||||
case ElementType::u32:
|
||||
std::copy(values.begin(), values.end(), this->values.u32->begin());
|
||||
break;
|
||||
case ElementType::i32:
|
||||
std::copy(values.begin(), values.end(), this->values.i32->begin());
|
||||
break;
|
||||
case ElementType::u16:
|
||||
std::copy(values.begin(), values.end(), this->values.u16->begin());
|
||||
break;
|
||||
case ElementType::i16:
|
||||
std::copy(values.begin(), values.end(), this->values.i16->begin());
|
||||
break;
|
||||
case ElementType::u8:
|
||||
std::copy(values.begin(), values.end(), this->values.u8->begin());
|
||||
break;
|
||||
case ElementType::i8:
|
||||
std::copy(values.begin(), values.end(), this->values.i8->begin());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Copies all elements of a flat representation of the tensor to the
|
||||
// positions starting with the iterator `start`.
|
||||
template <typename IT> void copy(IT start) const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
std::copy(this->values.u64->cbegin(), this->values.u64->cend(), start);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
std::copy(this->values.i64->cbegin(), this->values.i64->cend(), start);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
std::copy(this->values.u32->cbegin(), this->values.u32->cend(), start);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
std::copy(this->values.i32->cbegin(), this->values.i32->cend(), start);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
std::copy(this->values.u16->cbegin(), this->values.u16->cend(), start);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
std::copy(this->values.i16->cbegin(), this->values.i16->cend(), start);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
std::copy(this->values.u8->cbegin(), this->values.u8->cend(), start);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
std::copy(this->values.i8->cbegin(), this->values.i8->cend(), start);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a flat representation of the tensor with elements
|
||||
// converted to the type `T`
|
||||
template <typename T> std::vector<T> asFlatVector() const {
|
||||
std::vector<T> ret(getNumElements());
|
||||
this->copy(ret.begin());
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Returns a void pointer to the first element of a flat
|
||||
// representation of the tensor
|
||||
void *getValuesAsOpaquePointer() const {
|
||||
switch (this->elementType) {
|
||||
case ElementType::u64:
|
||||
return static_cast<void *>(values.u64->data());
|
||||
case ElementType::i64:
|
||||
return static_cast<void *>(values.i64->data());
|
||||
case ElementType::u32:
|
||||
return static_cast<void *>(values.u32->data());
|
||||
case ElementType::i32:
|
||||
return static_cast<void *>(values.i32->data());
|
||||
case ElementType::u16:
|
||||
return static_cast<void *>(values.u16->data());
|
||||
case ElementType::i16:
|
||||
return static_cast<void *>(values.i16->data());
|
||||
case ElementType::u8:
|
||||
return static_cast<void *>(values.u8->data());
|
||||
case ElementType::i8:
|
||||
return static_cast<void *>(values.i8->data());
|
||||
}
|
||||
|
||||
assert(false && "Unhandled element type");
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
namespace ScalarData {
|
||||
// Union representing a single scalar value
|
||||
union scalar_union {
|
||||
uint64_t u64;
|
||||
int64_t i64;
|
||||
uint32_t u32;
|
||||
int32_t i32;
|
||||
uint16_t u16;
|
||||
int16_t i16;
|
||||
uint8_t u8;
|
||||
int8_t i8;
|
||||
};
|
||||
|
||||
// Template + specializations that should be in ScalarData, but which need to be
|
||||
// in namespace scope
|
||||
template <typename T> T getValue(const union scalar_union &u, ElementType type);
|
||||
|
||||
#define SCALARDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
|
||||
template <> \
|
||||
inline ELTY getValue(const union scalar_union &u, ElementType type) { \
|
||||
assert(type == ElementType::SUFFIX); \
|
||||
return u.SUFFIX; \
|
||||
}
|
||||
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
|
||||
SCALARDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
|
||||
|
||||
} // namespace ScalarData
|
||||
} // namespace detail
|
||||
|
||||
// Class representing a single scalar value
|
||||
class ScalarData {
|
||||
public:
|
||||
ScalarData(const ScalarData &s)
|
||||
: type(s.type), value(s.value), width(s.width) {}
|
||||
|
||||
// Construction with a specific type and an actual width, but with a value
|
||||
// provided in a generic `uint64_t`
|
||||
ScalarData(uint64_t value, ElementType type, size_t width)
|
||||
: type(type), width(width) {
|
||||
assert(width <= getElementTypeWidth(type));
|
||||
|
||||
switch (type) {
|
||||
case ElementType::u64:
|
||||
this->value.u64 = value;
|
||||
break;
|
||||
case ElementType::i64:
|
||||
this->value.i64 = value;
|
||||
break;
|
||||
case ElementType::u32:
|
||||
this->value.u32 = value;
|
||||
break;
|
||||
case ElementType::i32:
|
||||
this->value.i32 = value;
|
||||
break;
|
||||
case ElementType::u16:
|
||||
this->value.u16 = value;
|
||||
break;
|
||||
case ElementType::i16:
|
||||
this->value.i16 = value;
|
||||
break;
|
||||
case ElementType::u8:
|
||||
this->value.u8 = value;
|
||||
break;
|
||||
case ElementType::i8:
|
||||
this->value.i8 = value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Construction with a specific type determined by `sign` and
|
||||
// `width`, but value provided in a generic `uint64_t`
|
||||
ScalarData(uint64_t value, bool sign, size_t width)
|
||||
: ScalarData(value, getElementTypeFromWidthAndSign(width, sign), width) {}
|
||||
|
||||
#define DEF_SCALAR_DATA_CONSTRUCTOR(ELTY, SUFFIX) \
|
||||
ScalarData(ELTY value) \
|
||||
: type(ElementType::SUFFIX), \
|
||||
width(getElementTypeWidth(ElementType::SUFFIX)) { \
|
||||
this->value.SUFFIX = value; \
|
||||
}
|
||||
|
||||
// Construction from specific value type
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint64_t, u64)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int64_t, i64)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint32_t, u32)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int32_t, i32)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint16_t, u16)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int16_t, i16)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(uint8_t, u8)
|
||||
DEF_SCALAR_DATA_CONSTRUCTOR(int8_t, i8)
|
||||
|
||||
template <typename T> T getValue() const {
|
||||
return detail::ScalarData::getValue<T>(value, type);
|
||||
}
|
||||
|
||||
// Retrieves the value as a generic `uint64_t`
|
||||
uint64_t getValueAsU64() const {
|
||||
size_t width = getElementTypeWidth(type);
|
||||
if (width == 64)
|
||||
return value.u64;
|
||||
uint64_t mask = ((uint64_t)1 << width) - 1;
|
||||
uint64_t val = value.u64 & mask;
|
||||
return val;
|
||||
}
|
||||
|
||||
ElementType getType() const { return type; }
|
||||
size_t getWidth() const { return width; }
|
||||
|
||||
protected:
|
||||
ElementType type;
|
||||
union detail::ScalarData::scalar_union value;
|
||||
size_t width;
|
||||
};
|
||||
|
||||
// Variant for TensorData and ScalarData
|
||||
class ScalarOrTensorData {
|
||||
protected:
|
||||
std::unique_ptr<ScalarData> scalar;
|
||||
std::unique_ptr<TensorData> tensor;
|
||||
|
||||
public:
|
||||
ScalarOrTensorData(const ScalarOrTensorData &td) = delete;
|
||||
ScalarOrTensorData(ScalarOrTensorData &&td)
|
||||
: scalar(std::move(td.scalar)), tensor(std::move(td.tensor)) {}
|
||||
|
||||
ScalarOrTensorData(TensorData &&td)
|
||||
: scalar(nullptr), tensor(std::make_unique<TensorData>(std::move(td))) {}
|
||||
|
||||
ScalarOrTensorData(const ScalarData &s)
|
||||
: scalar(std::make_unique<ScalarData>(s)), tensor(nullptr) {}
|
||||
|
||||
bool isTensor() const { return tensor != nullptr; }
|
||||
bool isScalar() const { return scalar != nullptr; }
|
||||
|
||||
ScalarData &getScalar() {
|
||||
assert(scalar != nullptr &&
|
||||
"Attempt to get a scalar value from variant that is a tensor");
|
||||
return *scalar;
|
||||
}
|
||||
|
||||
const ScalarData &getScalar() const {
|
||||
assert(scalar != nullptr &&
|
||||
"Attempt to get a scalar value from variant that is a tensor");
|
||||
return *scalar;
|
||||
}
|
||||
|
||||
TensorData &getTensor() {
|
||||
assert(tensor != nullptr &&
|
||||
"Attempt to get a tensor value from variant that is a scalar");
|
||||
return *tensor;
|
||||
}
|
||||
|
||||
const TensorData &getTensor() const {
|
||||
assert(tensor != nullptr &&
|
||||
"Attempt to get a tensor value from variant that is a scalar");
|
||||
return *tensor;
|
||||
}
|
||||
};
|
||||
|
||||
struct SharedScalarOrTensorData {
|
||||
std::shared_ptr<ScalarOrTensorData> inner;
|
||||
|
||||
SharedScalarOrTensorData(std::shared_ptr<ScalarOrTensorData> inner)
|
||||
: inner{inner} {}
|
||||
|
||||
SharedScalarOrTensorData(ScalarOrTensorData &&inner)
|
||||
: inner{std::make_shared<ScalarOrTensorData>(std::move(inner))} {}
|
||||
|
||||
ScalarOrTensorData &get() const { return *this->inner; }
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,239 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_VALUE_DECRYPTER_H
|
||||
#define CONCRETELANG_CLIENTLIB_VALUE_DECRYPTER_H
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class ValueDecrypterInterface {
|
||||
protected:
|
||||
virtual outcome::checked<uint64_t, StringError>
|
||||
decryptValue(size_t argPos, uint64_t *ciphertext) = 0;
|
||||
/// Size of the low-level ciphertext, taking into account the CRT if used
|
||||
virtual int64_t ciphertextSize(CircuitGate &gate) = 0;
|
||||
/// Output gate at position `argPos`
|
||||
virtual outcome::checked<CircuitGate, StringError>
|
||||
outputGate(size_t argPos) = 0;
|
||||
/// Whether the value decrypter is simulating encryption
|
||||
virtual bool isSimulated() = 0;
|
||||
|
||||
/// Whether argument at pos `argPos` is encrypted or not
|
||||
virtual outcome::checked<bool, StringError> isEncrypted(size_t argPos) {
|
||||
OUTCOME_TRY(auto gate, outputGate(argPos));
|
||||
return gate.isEncrypted();
|
||||
}
|
||||
|
||||
public:
|
||||
virtual ~ValueDecrypterInterface() = default;
|
||||
|
||||
/// @brief Transforms a FHE value into a clear scalar value
|
||||
/// @tparam T The type of the clear scalar value
|
||||
/// @param value The value to decrypt
|
||||
/// @param pos The position of the argument
|
||||
/// @return Either the decrypted value or an error if the gate doesn't match
|
||||
/// the expected result.
|
||||
template <typename T>
|
||||
outcome::checked<T, StringError> decrypt(ScalarOrTensorData &value,
|
||||
size_t pos) {
|
||||
OUTCOME_TRY(auto encrypted, isEncrypted(pos));
|
||||
if (!encrypted)
|
||||
return value.getScalar().getValue<T>();
|
||||
|
||||
if (isSimulated()) {
|
||||
OUTCOME_TRY(auto gate, outputGate(pos));
|
||||
auto crtVec = gate.encryption->encoding.crt;
|
||||
if (crtVec.empty()) {
|
||||
// value is a scalar
|
||||
auto ciphertext = value.getScalar().getValue<uint64_t>();
|
||||
OUTCOME_TRY(auto decrypted, decryptValue(pos, &ciphertext));
|
||||
return (T)decrypted;
|
||||
} else {
|
||||
// value is a tensor (crt)
|
||||
auto &buffer = value.getTensor();
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(0);
|
||||
OUTCOME_TRY(
|
||||
auto decrypted,
|
||||
decryptValue(pos, reinterpret_cast<uint64_t *>(ciphertext)));
|
||||
return (T)decrypted;
|
||||
}
|
||||
}
|
||||
|
||||
auto &buffer = value.getTensor();
|
||||
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(0);
|
||||
|
||||
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
|
||||
// FIXME: this may break alignment restrictions on some
|
||||
// architectures
|
||||
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
|
||||
OUTCOME_TRY(auto decrypted, decryptValue(pos, ciphertextu64));
|
||||
|
||||
return (T)decrypted;
|
||||
}
|
||||
|
||||
/// @brief Transforms a FHE value into a vector of clear value
|
||||
/// @tparam T The type of the clear scalar value
|
||||
/// @param value The value to decrypt
|
||||
/// @param pos The position of the argument
|
||||
/// @return Either the decrypted value or an error if the gate doesn't match
|
||||
/// the expected result.
|
||||
template <typename T>
|
||||
outcome::checked<std::vector<T>, StringError>
|
||||
decryptTensor(ScalarOrTensorData &value, size_t pos) {
|
||||
OUTCOME_TRY(auto encrypted, isEncrypted(pos));
|
||||
if (!encrypted)
|
||||
return value.getTensor().asFlatVector<T>();
|
||||
|
||||
auto &buffer = value.getTensor();
|
||||
OUTCOME_TRY(auto gate, outputGate(pos));
|
||||
auto lweSize = ciphertextSize(gate);
|
||||
|
||||
std::vector<T> decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
|
||||
|
||||
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
|
||||
// FIXME: this may break alignment restrictions on some
|
||||
// architectures
|
||||
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
|
||||
OUTCOME_TRY(auto decrypted, decryptValue(pos, ciphertextu64));
|
||||
decryptedValues[i] = decrypted;
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
/// Return the shape of the clear tensor of a result.
|
||||
outcome::checked<std::vector<int64_t>, StringError> getShape(size_t pos) {
|
||||
OUTCOME_TRY(auto gate, outputGate(pos));
|
||||
return gate.shape.dimensions;
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief allows to transform a serializable value into a clear value
|
||||
class ValueDecrypter : public ValueDecrypterInterface {
|
||||
public:
|
||||
ValueDecrypter(KeySet &keySet, ClientParameters clientParameters)
|
||||
: _keySet(keySet), _clientParameters(clientParameters) {}
|
||||
|
||||
protected:
|
||||
outcome::checked<uint64_t, StringError>
|
||||
decryptValue(size_t argPos, uint64_t *ciphertext) override {
|
||||
uint64_t decrypted;
|
||||
OUTCOME_TRYV(_keySet.decrypt_lwe(0, ciphertext, decrypted));
|
||||
return decrypted;
|
||||
};
|
||||
|
||||
bool isSimulated() override { return false; }
|
||||
|
||||
outcome::checked<CircuitGate, StringError>
|
||||
outputGate(size_t argPos) override {
|
||||
return _clientParameters.ouput(argPos);
|
||||
}
|
||||
|
||||
int64_t ciphertextSize(CircuitGate &gate) override {
|
||||
return _clientParameters.lweBufferSize(gate);
|
||||
}
|
||||
|
||||
private:
|
||||
KeySet &_keySet;
|
||||
ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
class SimulatedValueDecrypter : public ValueDecrypterInterface {
|
||||
public:
|
||||
SimulatedValueDecrypter(ClientParameters clientParameters)
|
||||
: _clientParameters(clientParameters) {}
|
||||
|
||||
protected:
|
||||
// TODO: a lot of this logic can be factorized when moving
|
||||
// `KeySet::decrypt_lwe` into the LWE ValueDecyrpter
|
||||
outcome::checked<uint64_t, StringError>
|
||||
decryptValue(size_t argPos, uint64_t *ciphertext) override {
|
||||
uint64_t output;
|
||||
OUTCOME_TRY(auto gate, outputGate(argPos));
|
||||
auto encoding = gate.encryption->encoding;
|
||||
auto precision = encoding.precision;
|
||||
auto crtVec = gate.encryption->encoding.crt;
|
||||
if (crtVec.empty()) {
|
||||
output = *ciphertext;
|
||||
output >>= (64 - precision - 2);
|
||||
auto carry = output % 2;
|
||||
uint64_t mod = (((uint64_t)1) << (precision + 1));
|
||||
output = ((output >> 1) + carry) % mod;
|
||||
// Further decode signed integers.
|
||||
if (encoding.isSigned) {
|
||||
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
|
||||
if (output >= maxPos) { // The output is actually negative.
|
||||
// Set the preceding bits to zero
|
||||
output |= UINT64_MAX << precision;
|
||||
// This makes sure when the value is cast to int64, it has the correct
|
||||
// value
|
||||
};
|
||||
}
|
||||
} else {
|
||||
// Decrypt and decode remainders
|
||||
std::vector<int64_t> remainders;
|
||||
for (auto modulus : crtVec) {
|
||||
output = *ciphertext;
|
||||
auto plaintext = crt::decode(output, modulus);
|
||||
remainders.push_back(plaintext);
|
||||
// each ciphertext is a scalar
|
||||
ciphertext = ciphertext + 1;
|
||||
}
|
||||
output = crt::iCrt(crtVec, remainders);
|
||||
// Further decode signed integers
|
||||
if (encoding.isSigned) {
|
||||
uint64_t maxPos = 1;
|
||||
for (auto prime : crtVec) {
|
||||
maxPos *= prime;
|
||||
}
|
||||
maxPos /= 2;
|
||||
if (output >= maxPos) {
|
||||
output -= maxPos * 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
bool isSimulated() override { return true; }
|
||||
|
||||
outcome::checked<CircuitGate, StringError>
|
||||
outputGate(size_t argPos) override {
|
||||
return _clientParameters.ouput(argPos);
|
||||
}
|
||||
|
||||
/// @brief Ciphertext size in simulation
|
||||
/// When using CRT encoding, it's the number of blocks, otherwise, it's just 1
|
||||
/// scalar
|
||||
/// @param gate
|
||||
/// @return number of scalars to represent one input
|
||||
int64_t ciphertextSize(CircuitGate &gate) override {
|
||||
// ciphertext in simulation are only scalars
|
||||
assert(gate.encryption.has_value());
|
||||
auto crtSize = gate.encryption->encoding.crt.size();
|
||||
return crtSize == 0 ? 1 : crtSize;
|
||||
}
|
||||
|
||||
private:
|
||||
ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,283 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_VALUE_EXPORTER_H
|
||||
#define CONCRETELANG_CLIENTLIB_VALUE_EXPORTER_H
|
||||
|
||||
#include <ostream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/simulation.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class ValueExporterInterface {
|
||||
protected:
|
||||
virtual outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
|
||||
size_t argPos,
|
||||
uint64_t *ciphertext,
|
||||
uint64_t input) = 0;
|
||||
/// Encrypt and export a 64bits integer to a serializale value
|
||||
virtual outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) = 0;
|
||||
/// Shape of the low-level buffer
|
||||
virtual std::vector<int64_t> bufferShape(CircuitGate &gate) = 0;
|
||||
/// Size of the low-level ciphertext, taking into account the CRT if used
|
||||
virtual int64_t ciphertextSize(CircuitGate &gate) = 0;
|
||||
/// Input gate at position `argPos`
|
||||
virtual outcome::checked<CircuitGate, StringError>
|
||||
inputGate(size_t argPos) = 0;
|
||||
|
||||
public:
|
||||
virtual ~ValueExporterInterface() = default;
|
||||
|
||||
/// @brief Export a scalar 64 bits integer to a concreteprocol::Value
|
||||
/// @param arg An 64 bits integer
|
||||
/// @param argPos The position of the argument to export
|
||||
/// @return Either the exported value ready to be sent to the server or an
|
||||
/// error if the gate doesn't match the expected argument.
|
||||
outcome::checked<ScalarOrTensorData, StringError> exportValue(uint64_t arg,
|
||||
size_t argPos) {
|
||||
OUTCOME_TRY(auto gate, inputGate(argPos));
|
||||
if (gate.shape.size != 0) {
|
||||
return StringError("argument #") << argPos << " is not a scalar";
|
||||
}
|
||||
if (gate.encryption.has_value()) {
|
||||
return exportEncryptValue(arg, gate, argPos);
|
||||
}
|
||||
return exportClearValue(arg);
|
||||
}
|
||||
|
||||
/// @brief Export a tensor like buffer of values to a serializable value
|
||||
/// @tparam T The type of values hold by the buffer
|
||||
/// @param arg A pointer to a memory area where the values are stored
|
||||
/// @param shape The shape of the tensor
|
||||
/// @param argPos The position of the argument to export
|
||||
/// @return Either the exported value ready to be sent to the server or an
|
||||
/// error if the gate doesn't match the expected argument.
|
||||
template <typename T>
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportValue(const T *arg, llvm::ArrayRef<int64_t> shape, size_t argPos) {
|
||||
OUTCOME_TRY(auto gate, inputGate(argPos));
|
||||
OUTCOME_TRYV(checkShape(shape, gate.shape, argPos));
|
||||
if (gate.encryption.has_value()) {
|
||||
return exportEncryptTensor(arg, shape, gate, argPos);
|
||||
}
|
||||
return exportClearTensor(arg, shape, gate);
|
||||
}
|
||||
|
||||
protected:
|
||||
/// Export a 64bits integer to a serializable value
|
||||
virtual outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportClearValue(uint64_t arg) {
|
||||
return ScalarData(arg);
|
||||
}
|
||||
|
||||
/// Export a tensor like buffer to a serializable value
|
||||
template <typename T>
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportClearTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
|
||||
CircuitGate &gate) {
|
||||
auto bitsPerValue = bitWidthAsWord(gate.shape.width);
|
||||
auto sizes = bufferShape(gate);
|
||||
TensorData td(sizes, bitsPerValue, gate.shape.sign);
|
||||
llvm::ArrayRef<T> values(arg, TensorData::getNumElements(sizes));
|
||||
td.bulkAssign(values);
|
||||
return std::move(td);
|
||||
}
|
||||
|
||||
/// Export and encrypt a tensor like buffer to a serializable value
|
||||
template <typename T>
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportEncryptTensor(const T *arg, llvm::ArrayRef<int64_t> shape,
|
||||
CircuitGate &gate, size_t argPos) {
|
||||
// Create and allocate the TensorData that will holds encrypted values
|
||||
auto sizes = bufferShape(gate);
|
||||
TensorData td(sizes, EncryptedScalarElementType,
|
||||
EncryptedScalarElementWidth);
|
||||
|
||||
// Iterate over values and encrypt at the right place the value
|
||||
auto lweSize = ciphertextSize(gate);
|
||||
for (size_t i = 0, offset = 0; i < gate.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
OUTCOME_TRYV(encryptValue(
|
||||
gate, argPos, td.getElementPointer<uint64_t>(offset), arg[i]));
|
||||
}
|
||||
return std::move(td);
|
||||
}
|
||||
|
||||
private:
|
||||
static outcome::checked<void, StringError>
|
||||
checkShape(llvm::ArrayRef<int64_t> shape, CircuitGateShape expected,
|
||||
size_t argPos) {
|
||||
// Check the shape of tensor
|
||||
if (expected.dimensions.empty()) {
|
||||
return StringError("argument #") << argPos << "is not a tensor";
|
||||
}
|
||||
if (shape.size() != expected.dimensions.size()) {
|
||||
return StringError("argument #")
|
||||
<< argPos << "has not the expected number of dimension, got "
|
||||
<< shape.size() << " expected " << expected.dimensions.size();
|
||||
}
|
||||
|
||||
// Check shape
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
if (shape[i] != expected.dimensions[i]) {
|
||||
return StringError("argument #")
|
||||
<< argPos << " has not the expected dimension #" << i
|
||||
<< " , got " << shape[i] << " expected "
|
||||
<< expected.dimensions[i];
|
||||
}
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief The ArgumentsExporter allows to transform clear
|
||||
/// arguments to the one expected by a server lambda.
|
||||
class ValueExporter : public ValueExporterInterface {
|
||||
|
||||
public:
|
||||
/// @brief
|
||||
/// @param keySet
|
||||
/// @param clientParameters
|
||||
// TODO: Get rid of the reference here could make troubles (see for KeySet
|
||||
// copy constructor or shared pointers)
|
||||
ValueExporter(KeySet &keySet, ClientParameters clientParameters)
|
||||
: _keySet(keySet), _clientParameters(clientParameters) {}
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
|
||||
size_t argPos,
|
||||
uint64_t *ciphertext,
|
||||
uint64_t input) override {
|
||||
return _keySet.encrypt_lwe(argPos, ciphertext, input);
|
||||
}
|
||||
|
||||
outcome::checked<CircuitGate, StringError> inputGate(size_t argPos) override {
|
||||
return _clientParameters.input(argPos);
|
||||
}
|
||||
|
||||
std::vector<int64_t> bufferShape(CircuitGate &gate) override {
|
||||
return _clientParameters.bufferShape(gate);
|
||||
}
|
||||
|
||||
int64_t ciphertextSize(CircuitGate &gate) override {
|
||||
return _clientParameters.lweBufferSize(gate);
|
||||
}
|
||||
|
||||
/// Encrypt and export a 64bits integer to a serializale value
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) override {
|
||||
std::vector<int64_t> shape = _clientParameters.bufferShape(gate);
|
||||
|
||||
// Create and allocate the TensorData that will holds encrypted value
|
||||
TensorData td(shape, clientlib::EncryptedScalarElementType,
|
||||
clientlib::EncryptedScalarElementWidth);
|
||||
|
||||
// Encrypt the value
|
||||
OUTCOME_TRYV(
|
||||
encryptValue(gate, argPos, td.getElementPointer<uint64_t>(0), arg));
|
||||
return std::move(td);
|
||||
}
|
||||
|
||||
private:
|
||||
KeySet &_keySet;
|
||||
ClientParameters _clientParameters;
|
||||
};
|
||||
|
||||
/// @brief The SimulatedValueExporter allows to transform clear
|
||||
/// arguments to the one expected by a server lambda during simulation.
|
||||
class SimulatedValueExporter : public ValueExporterInterface {
|
||||
|
||||
public:
|
||||
SimulatedValueExporter(ClientParameters clientParameters)
|
||||
: _clientParameters(clientParameters), csprng(0) {}
|
||||
|
||||
protected:
|
||||
outcome::checked<void, StringError> encryptValue(CircuitGate &gate,
|
||||
size_t argPos,
|
||||
uint64_t *ciphertext,
|
||||
uint64_t input) override {
|
||||
auto crtVec = gate.encryption->encoding.crt;
|
||||
if (crtVec.empty()) {
|
||||
auto precision = gate.encryption->encoding.precision;
|
||||
OUTCOME_TRY(auto skParam, _clientParameters.lweSecretKeyParam(gate));
|
||||
auto lwe_dim = skParam.lweDimension();
|
||||
auto encoded_input = input << (64 - (precision + 1));
|
||||
*ciphertext =
|
||||
sim_encrypt_lwe_u64(encoded_input, lwe_dim, (void *)csprng.ptr);
|
||||
} else {
|
||||
// Put each decomposition into a new ciphertext
|
||||
auto product = concretelang::clientlib::crt::productOfModuli(crtVec);
|
||||
for (auto modulus : crtVec) {
|
||||
OUTCOME_TRY(auto skParam, _clientParameters.lweSecretKeyParam(gate));
|
||||
auto lwe_dim = skParam.lweDimension();
|
||||
auto plaintext = crt::encode(input, modulus, product);
|
||||
*ciphertext =
|
||||
sim_encrypt_lwe_u64(plaintext, lwe_dim, (void *)csprng.ptr);
|
||||
// each ciphertext is a scalar
|
||||
ciphertext = ciphertext + 1;
|
||||
}
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
/// Simulate encrypt and export a 64bits integer to a serializale value
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
exportEncryptValue(uint64_t arg, CircuitGate &gate, size_t argPos) override {
|
||||
auto crtVec = gate.encryption->encoding.crt;
|
||||
if (crtVec.empty()) {
|
||||
uint64_t encValue = 0;
|
||||
OUTCOME_TRYV(encryptValue(gate, argPos, &encValue, arg));
|
||||
return ScalarData(encValue);
|
||||
} else {
|
||||
TensorData td(bufferShape(gate), clientlib::EncryptedScalarElementType,
|
||||
clientlib::EncryptedScalarElementWidth);
|
||||
OUTCOME_TRYV(
|
||||
encryptValue(gate, argPos, td.getElementPointer<uint64_t>(0), arg));
|
||||
return std::move(td);
|
||||
}
|
||||
}
|
||||
|
||||
outcome::checked<CircuitGate, StringError> inputGate(size_t argPos) override {
|
||||
return _clientParameters.input(argPos);
|
||||
}
|
||||
|
||||
std::vector<int64_t> bufferShape(CircuitGate &gate) override {
|
||||
return _clientParameters.bufferShape(gate, true);
|
||||
}
|
||||
|
||||
/// @brief Ciphertext size in simulation
|
||||
/// When using CRT encoding, it's the number of blocks, otherwise, it's just 1
|
||||
/// scalar
|
||||
/// @param gate
|
||||
/// @return number of scalars to represent one input
|
||||
int64_t ciphertextSize(CircuitGate &gate) override {
|
||||
// ciphertext in simulation are only scalars
|
||||
assert(gate.encryption.has_value());
|
||||
auto crtSize = gate.encryption->encoding.crt.size();
|
||||
return crtSize == 0 ? 1 : crtSize;
|
||||
}
|
||||
|
||||
private:
|
||||
ClientParameters _clientParameters;
|
||||
concretelang::clientlib::ConcreteCSPRNG csprng;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,19 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
#define CONCRETELANG_COMMON_BITS_SIZE_H
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace concretelang {
|
||||
namespace common {
|
||||
|
||||
size_t bitWidthAsWord(size_t exactBitWidth);
|
||||
|
||||
}
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -3,14 +3,13 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
#define CONCRETELANG_CLIENTLIB_CRT_H_
|
||||
#ifndef CONCRETELANG_COMMON_CRT_H_
|
||||
#define CONCRETELANG_COMMON_CRT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
namespace crt {
|
||||
|
||||
/// Compute the product of the moduli of the crt decomposition.
|
||||
@@ -40,7 +39,6 @@ uint64_t encode(int64_t plaintext, uint64_t modulus, uint64_t product);
|
||||
uint64_t decode(uint64_t val, uint64_t modulus);
|
||||
|
||||
} // namespace crt
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,514 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
//
|
||||
// NOTE:
|
||||
// -----
|
||||
// To limit the size of the refactoring, we chose to not propagate the new
|
||||
// client/server lib to the exterior APIs after it was finalized. This file only
|
||||
// serves as a compatibility layer for exterior (python/rust/c) apis, for the
|
||||
// time being.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_COMPAT
|
||||
#define CONCRETELANG_COMMON_COMPAT
|
||||
|
||||
#include "capnp/serialize-packed.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/ClientLib/ClientLib.h"
|
||||
#include "concretelang/Common/Keys.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include "concretelang/ServerLib/ServerLib.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "kj/io.h"
|
||||
#include "kj/std/iostream.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
|
||||
using concretelang::clientlib::ClientCircuit;
|
||||
using concretelang::clientlib::ClientProgram;
|
||||
using concretelang::keysets::Keyset;
|
||||
using concretelang::keysets::KeysetCache;
|
||||
using concretelang::keysets::ServerKeyset;
|
||||
using concretelang::serverlib::ServerCircuit;
|
||||
using concretelang::serverlib::ServerProgram;
|
||||
using concretelang::values::TransportValue;
|
||||
using concretelang::values::Value;
|
||||
|
||||
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
|
||||
auto VARNAME = EXPECTED; \
|
||||
if (auto err = VARNAME.takeError()) { \
|
||||
throw std::runtime_error(llvm::toString(std::move(err))); \
|
||||
}
|
||||
|
||||
#define CONCAT(a, b) CONCAT_INNER(a, b)
|
||||
#define CONCAT_INNER(a, b) a##b
|
||||
|
||||
#define GET_OR_THROW_RESULT_(VARNAME, RESULT, MAYBE) \
|
||||
auto MAYBE = RESULT; \
|
||||
if (MAYBE.has_failure()) { \
|
||||
throw std::runtime_error(MAYBE.as_failure().error().mesg); \
|
||||
} \
|
||||
VARNAME = MAYBE.value();
|
||||
|
||||
#define GET_OR_THROW_RESULT(VARNAME, RESULT) \
|
||||
GET_OR_THROW_RESULT_(VARNAME, RESULT, CONCAT(maybe, __COUNTER__))
|
||||
|
||||
#define EXPECTED_TRY_(lhs, rhs, maybe) \
|
||||
auto maybe = rhs; \
|
||||
if (auto err = maybe.takeError()) { \
|
||||
return std::move(err); \
|
||||
} \
|
||||
lhs = *maybe;
|
||||
|
||||
#define EXPECTED_TRY(lhs, rhs) \
|
||||
EXPECTED_TRY_(lhs, rhs, CONCAT(maybe, __COUNTER__))
|
||||
|
||||
template <typename T> llvm::Expected<T> outcomeToExpected(Result<T> outcome) {
|
||||
if (outcome.has_failure()) {
|
||||
return mlir::concretelang::StreamStringError(
|
||||
outcome.as_failure().error().mesg);
|
||||
} else {
|
||||
return outcome.value();
|
||||
}
|
||||
}
|
||||
|
||||
// Every number sent by python through the API has a type `int64` that must be
|
||||
// turned into the proper type expected by the ArgTransformers. This allows to
|
||||
// get an extra transformer executed right before the ArgTransformer gets
|
||||
// called.
|
||||
std::function<Value(Value)>
|
||||
getPythonTypeTransformer(const Message<concreteprotocol::GateInfo> &info) {
|
||||
if (info.asReader().getTypeInfo().hasIndex()) {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint64_t>)tensorInput};
|
||||
};
|
||||
} else if (info.asReader().getTypeInfo().hasPlaintext()) {
|
||||
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
|
||||
8) {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint8_t>)tensorInput};
|
||||
};
|
||||
}
|
||||
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
|
||||
16) {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint16_t>)tensorInput};
|
||||
};
|
||||
}
|
||||
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
|
||||
32) {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint32_t>)tensorInput};
|
||||
};
|
||||
}
|
||||
if (info.asReader().getTypeInfo().getPlaintext().getIntegerPrecision() <=
|
||||
64) {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint64_t>)tensorInput};
|
||||
};
|
||||
}
|
||||
assert(false);
|
||||
} else if (info.asReader().getTypeInfo().hasLweCiphertext()) {
|
||||
if (info.asReader()
|
||||
.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.hasInteger() &&
|
||||
info.asReader()
|
||||
.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.getInteger()
|
||||
.getIsSigned()) {
|
||||
return [=](Value input) { return input; };
|
||||
} else {
|
||||
return [=](Value input) {
|
||||
Tensor<int64_t> tensorInput = input.getTensor<int64_t>().value();
|
||||
return Value{(Tensor<uint64_t>)tensorInput};
|
||||
};
|
||||
}
|
||||
} else {
|
||||
assert(false);
|
||||
}
|
||||
};
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct ServerLambda {
|
||||
ServerCircuit circuit;
|
||||
bool isSimulation;
|
||||
};
|
||||
} // namespace serverlib
|
||||
|
||||
namespace clientlib {
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct LweSecretKeyParam {
|
||||
Message<concreteprotocol::LweSecretKeyInfo> info;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct BootstrapKeyParam {
|
||||
Message<concreteprotocol::LweBootstrapKeyInfo> info;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct KeyswitchKeyParam {
|
||||
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct PackingKeyswitchKeyParam {
|
||||
Message<concreteprotocol::PackingKeyswitchKeyInfo> info;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct Encoding {
|
||||
Message<concreteprotocol::EncodingInfo> circuit;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct EncryptionGate {
|
||||
Message<concreteprotocol::GateInfo> gateInfo;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct CircuitGate {
|
||||
Message<concreteprotocol::GateInfo> gateInfo;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct ValueExporter {
|
||||
ClientCircuit circuit;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct SimulatedValueExporter {
|
||||
ClientCircuit circuit;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct ValueDecrypter {
|
||||
ClientCircuit circuit;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct SimulatedValueDecrypter {
|
||||
ClientCircuit circuit;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct PublicArguments {
|
||||
std::vector<TransportValue> values;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct PublicResult {
|
||||
std::vector<TransportValue> values;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct SharedScalarOrTensorData {
|
||||
TransportValue value;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct ClientParameters {
|
||||
Message<concreteprotocol::ProgramInfo> programInfo;
|
||||
std::vector<LweSecretKeyParam> secretKeys;
|
||||
std::vector<BootstrapKeyParam> bootstrapKeys;
|
||||
std::vector<KeyswitchKeyParam> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKeyParam> packingKeyswitchKeys;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct EvaluationKeys {
|
||||
ServerKeyset keyset;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct KeySetCache {
|
||||
KeysetCache keysetCache;
|
||||
};
|
||||
|
||||
/// A transition structure that preserver the current API of the library
|
||||
/// support.
|
||||
struct KeySet {
|
||||
Keyset keyset;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
/// A transition structure that preserves the current API of the library
|
||||
/// support.
|
||||
struct LambdaArgument {
|
||||
::concretelang::values::Value value;
|
||||
};
|
||||
|
||||
/// LibraryCompilationResult is the result of a compilation to a library.
|
||||
struct LibraryCompilationResult {
|
||||
/// The output directory path where the compilation artifacts have been
|
||||
/// generated.
|
||||
std::string outputDirPath;
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
class LibrarySupport {
|
||||
|
||||
public:
|
||||
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
|
||||
bool generateSharedLib = true, bool generateStaticLib = true,
|
||||
bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true)
|
||||
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
|
||||
generateSharedLib(generateSharedLib),
|
||||
generateStaticLib(generateStaticLib),
|
||||
generateClientParameters(generateClientParameters),
|
||||
generateCompilationFeedback(generateCompilationFeedback) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) {
|
||||
// Setup the compiler engine
|
||||
auto context = CompilationContext::createShared();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
engine.setCompilationOptions(options);
|
||||
|
||||
// Compile to a library
|
||||
auto library =
|
||||
engine.compile(program, outputPath, runtimeLibraryPath,
|
||||
generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCompilationFeedback);
|
||||
if (auto err = library.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.mainFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
this->funcName = options.mainFuncName.value();
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = *options.mainFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::StringRef s, CompilationOptions options) {
|
||||
std::unique_ptr<llvm::MemoryBuffer> mb =
|
||||
llvm::MemoryBuffer::getMemBuffer(s);
|
||||
llvm::SourceMgr sm;
|
||||
sm.AddNewSourceBuffer(std::move(mb), llvm::SMLoc());
|
||||
return this->compile(sm, options);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(mlir::ModuleOp &program,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> &context,
|
||||
CompilationOptions options) {
|
||||
|
||||
// Setup the compiler engine
|
||||
concretelang::CompilerEngine engine(context);
|
||||
engine.setCompilationOptions(options);
|
||||
|
||||
// Compile to a library
|
||||
auto library =
|
||||
engine.compile(program, outputPath, runtimeLibraryPath,
|
||||
generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCompilationFeedback);
|
||||
if (auto err = library.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.mainFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
this->funcName = options.mainFuncName.value();
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = *options.mainFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<::concretelang::serverlib::ServerLambda>
|
||||
loadServerLambda(LibraryCompilationResult &result, bool useSimulation) {
|
||||
EXPECTED_TRY(auto programInfo, getProgramInfo());
|
||||
EXPECTED_TRY(ServerProgram serverProgram,
|
||||
outcomeToExpected(ServerProgram::load(programInfo.asReader(),
|
||||
getSharedLibPath(),
|
||||
useSimulation)));
|
||||
EXPECTED_TRY(
|
||||
ServerCircuit serverCircuit,
|
||||
outcomeToExpected(serverProgram.getServerCircuit(result.funcName)));
|
||||
return ::concretelang::serverlib::ServerLambda{serverCircuit,
|
||||
useSimulation};
|
||||
}
|
||||
|
||||
/// Load the client parameters from the compilation result.
|
||||
llvm::Expected<::concretelang::clientlib::ClientParameters>
|
||||
loadClientParameters(LibraryCompilationResult &result) {
|
||||
EXPECTED_TRY(auto programInfo, getProgramInfo());
|
||||
if (programInfo.asReader().getCircuits().size() > 1) {
|
||||
return StreamStringError("ClientLambda: Provided program info contains "
|
||||
"more than one circuit.");
|
||||
}
|
||||
if (programInfo.asReader().getCircuits()[0].getName() != result.funcName) {
|
||||
return StreamStringError("Unexpected circuit name in program info");
|
||||
}
|
||||
auto secretKeys =
|
||||
std::vector<::concretelang::clientlib::LweSecretKeyParam>();
|
||||
for (auto key : programInfo.asReader().getKeyset().getLweSecretKeys()) {
|
||||
secretKeys.push_back(::concretelang::clientlib::LweSecretKeyParam{key});
|
||||
}
|
||||
auto boostrapKeys =
|
||||
std::vector<::concretelang::clientlib::BootstrapKeyParam>();
|
||||
for (auto key : programInfo.asReader().getKeyset().getLweBootstrapKeys()) {
|
||||
boostrapKeys.push_back(::concretelang::clientlib::BootstrapKeyParam{key});
|
||||
}
|
||||
auto keyswitchKeys =
|
||||
std::vector<::concretelang::clientlib::KeyswitchKeyParam>();
|
||||
for (auto key : programInfo.asReader().getKeyset().getLweKeyswitchKeys()) {
|
||||
keyswitchKeys.push_back(
|
||||
::concretelang::clientlib::KeyswitchKeyParam{key});
|
||||
}
|
||||
auto packingKeyswitchKeys =
|
||||
std::vector<::concretelang::clientlib::PackingKeyswitchKeyParam>();
|
||||
for (auto key :
|
||||
programInfo.asReader().getKeyset().getPackingKeyswitchKeys()) {
|
||||
packingKeyswitchKeys.push_back(
|
||||
::concretelang::clientlib::PackingKeyswitchKeyParam{key});
|
||||
}
|
||||
return ::concretelang::clientlib::ClientParameters{
|
||||
programInfo, secretKeys, boostrapKeys, keyswitchKeys,
|
||||
packingKeyswitchKeys};
|
||||
}
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramInfo>> getProgramInfo() {
|
||||
auto path = CompilerEngine::Library::getProgramInfoPath(outputPath);
|
||||
std::ifstream file(path);
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
if (file.fail()) {
|
||||
return StreamStringError("Cannot read file: ") << path;
|
||||
}
|
||||
auto output = Message<concreteprotocol::ProgramInfo>();
|
||||
if (output.readJsonFromString(content).has_failure()) {
|
||||
return StreamStringError("Cannot read json string.");
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/// Load the the compilation result if circuit already compiled
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
loadCompilationResult() {
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = funcName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
llvm::Expected<CompilationFeedback>
|
||||
loadCompilationFeedback(LibraryCompilationResult &result) {
|
||||
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
|
||||
result.outputDirPath);
|
||||
auto feedback = CompilationFeedback::load(path);
|
||||
if (feedback.has_error()) {
|
||||
return StreamStringError(feedback.error().mesg);
|
||||
}
|
||||
return feedback.value();
|
||||
}
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<::concretelang::clientlib::PublicResult>>
|
||||
serverCall(::concretelang::serverlib::ServerLambda lambda,
|
||||
::concretelang::clientlib::PublicArguments &args,
|
||||
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
if (lambda.isSimulation) {
|
||||
return mlir::concretelang::StreamStringError(
|
||||
"Tried to perform server call on simulation lambda.");
|
||||
}
|
||||
EXPECTED_TRY(auto output, outcomeToExpected(lambda.circuit.call(
|
||||
evaluationKeys.keyset, args.values)));
|
||||
::concretelang::clientlib::PublicResult res{output};
|
||||
return std::make_unique<::concretelang::clientlib::PublicResult>(
|
||||
std::move(res));
|
||||
}
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<::concretelang::clientlib::PublicResult>>
|
||||
simulate(::concretelang::serverlib::ServerLambda lambda,
|
||||
::concretelang::clientlib::PublicArguments &args) {
|
||||
if (!lambda.isSimulation) {
|
||||
return mlir::concretelang::StreamStringError(
|
||||
"Tried to perform simulation on execution lambda.");
|
||||
}
|
||||
EXPECTED_TRY(auto output,
|
||||
outcomeToExpected(lambda.circuit.simulate(args.values)));
|
||||
::concretelang::clientlib::PublicResult res{output};
|
||||
return std::make_unique<::concretelang::clientlib::PublicResult>(
|
||||
std::move(res));
|
||||
}
|
||||
|
||||
/// Get path to shared library
|
||||
std::string getSharedLibPath() {
|
||||
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
|
||||
}
|
||||
|
||||
/// Get path to client parameters file
|
||||
std::string getProgramInfoPath() {
|
||||
return CompilerEngine::Library::getProgramInfoPath(outputPath);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string outputPath;
|
||||
std::string funcName;
|
||||
std::string runtimeLibraryPath;
|
||||
/// Flags to select generated artifacts
|
||||
bool generateSharedLib;
|
||||
bool generateStaticLib;
|
||||
bool generateClientParameters;
|
||||
bool generateCompilationFeedback;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,46 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_CSPRNG_H
|
||||
#define CONCRETELANG_COMMON_CSPRNG_H
|
||||
|
||||
#include <assert.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
struct Csprng;
|
||||
struct CsprngVtable;
|
||||
|
||||
namespace concretelang {
|
||||
namespace csprng {
|
||||
|
||||
class CSPRNG {
|
||||
public:
|
||||
struct Csprng *ptr;
|
||||
const struct CsprngVtable *vtable;
|
||||
|
||||
CSPRNG() = delete;
|
||||
CSPRNG(CSPRNG &) = delete;
|
||||
|
||||
CSPRNG(CSPRNG &&other) : ptr(other.ptr), vtable(other.vtable) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
};
|
||||
|
||||
CSPRNG(Csprng *ptr, const CsprngVtable *vtable) : ptr(ptr), vtable(vtable){};
|
||||
};
|
||||
|
||||
class ConcreteCSPRNG : public CSPRNG {
|
||||
public:
|
||||
ConcreteCSPRNG(__uint128_t seed);
|
||||
ConcreteCSPRNG() = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &) = delete;
|
||||
ConcreteCSPRNG(ConcreteCSPRNG &&other);
|
||||
~ConcreteCSPRNG();
|
||||
};
|
||||
|
||||
} // namespace csprng
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -5,11 +5,13 @@
|
||||
#ifndef CONCRETELANG_COMMON_ERROR_H
|
||||
#define CONCRETELANG_COMMON_ERROR_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include <string>
|
||||
|
||||
namespace concretelang {
|
||||
namespace error {
|
||||
|
||||
/// The type of error used throughout the client/server libs.
|
||||
class StringError {
|
||||
public:
|
||||
StringError(std::string mesg) : mesg(mesg){};
|
||||
@@ -37,6 +39,9 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
/// A result type used throughout the client/server libs.
|
||||
template <typename T> using Result = outcome::checked<T, StringError>;
|
||||
|
||||
} // namespace error
|
||||
} // namespace concretelang
|
||||
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_KEYS_H
|
||||
#define CONCRETELANG_COMMON_KEYS_H
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include <memory>
|
||||
#include <stdlib.h>
|
||||
#include <vector>
|
||||
|
||||
using concretelang::csprng::CSPRNG;
|
||||
using concretelang::protocol::Message;
|
||||
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
inline void getApproval() {
|
||||
std::cerr << "DANGER: You are using an empty unsecure secret keys. Enter "
|
||||
"\"y\" to continue: ";
|
||||
char answer;
|
||||
std::cin >> answer;
|
||||
if (answer != 'y') {
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace concretelang {
|
||||
namespace keys {
|
||||
|
||||
/// An object representing an lwe Secret key
|
||||
class LweSecretKey {
|
||||
friend class Keyset;
|
||||
friend class KeysetCache;
|
||||
friend class LweBootstrapKey;
|
||||
friend class LweKeyswitchKey;
|
||||
friend class PackingKeyswitchKey;
|
||||
|
||||
public:
|
||||
LweSecretKey(Message<concreteprotocol::LweSecretKeyInfo> info,
|
||||
CSPRNG &csprng);
|
||||
LweSecretKey() = delete;
|
||||
LweSecretKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::LweSecretKeyInfo> info)
|
||||
: buffer(buffer), info(info){};
|
||||
|
||||
static LweSecretKey
|
||||
fromProto(const Message<concreteprotocol::LweSecretKey> &proto);
|
||||
|
||||
Message<concreteprotocol::LweSecretKey> toProto() const;
|
||||
|
||||
const uint64_t *getRawPtr() const;
|
||||
|
||||
size_t getSize() const;
|
||||
|
||||
const Message<concreteprotocol::LweSecretKeyInfo> &getInfo() const;
|
||||
|
||||
const std::vector<uint64_t> &getBuffer() const;
|
||||
|
||||
typedef Message<concreteprotocol::LweSecretKeyInfo> InfoType;
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> buffer;
|
||||
Message<concreteprotocol::LweSecretKeyInfo> info;
|
||||
};
|
||||
|
||||
class LweBootstrapKey {
|
||||
friend class Keyset;
|
||||
|
||||
public:
|
||||
LweBootstrapKey(Message<concreteprotocol::LweBootstrapKeyInfo> info,
|
||||
const LweSecretKey &inputKey, const LweSecretKey &outputKey,
|
||||
CSPRNG &csprng);
|
||||
LweBootstrapKey() = delete;
|
||||
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::LweBootstrapKeyInfo> info)
|
||||
: buffer(buffer), info(info){};
|
||||
|
||||
static LweBootstrapKey
|
||||
fromProto(const Message<concreteprotocol::LweBootstrapKey> &proto);
|
||||
|
||||
Message<concreteprotocol::LweBootstrapKey> toProto() const;
|
||||
|
||||
const uint64_t *getRawPtr() const;
|
||||
|
||||
size_t getSize() const;
|
||||
|
||||
const Message<concreteprotocol::LweBootstrapKeyInfo> &getInfo() const;
|
||||
|
||||
const std::vector<uint64_t> &getBuffer() const;
|
||||
|
||||
typedef Message<concreteprotocol::LweBootstrapKeyInfo> InfoType;
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> buffer;
|
||||
Message<concreteprotocol::LweBootstrapKeyInfo> info;
|
||||
};
|
||||
|
||||
class LweKeyswitchKey {
|
||||
friend class Keyset;
|
||||
|
||||
public:
|
||||
LweKeyswitchKey(Message<concreteprotocol::LweKeyswitchKeyInfo> info,
|
||||
const LweSecretKey &inputKey, const LweSecretKey &outputKey,
|
||||
CSPRNG &csprng);
|
||||
LweKeyswitchKey() = delete;
|
||||
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
|
||||
: buffer(buffer), info(info){};
|
||||
|
||||
static LweKeyswitchKey
|
||||
fromProto(const Message<concreteprotocol::LweKeyswitchKey> &proto);
|
||||
|
||||
Message<concreteprotocol::LweKeyswitchKey> toProto() const;
|
||||
|
||||
const uint64_t *getRawPtr() const;
|
||||
|
||||
size_t getSize() const;
|
||||
|
||||
const Message<concreteprotocol::LweKeyswitchKeyInfo> &getInfo() const;
|
||||
|
||||
const std::vector<uint64_t> &getBuffer() const;
|
||||
|
||||
typedef Message<concreteprotocol::LweKeyswitchKeyInfo> InfoType;
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> buffer;
|
||||
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
|
||||
};
|
||||
|
||||
class PackingKeyswitchKey {
|
||||
friend class Keyset;
|
||||
|
||||
public:
|
||||
PackingKeyswitchKey(Message<concreteprotocol::PackingKeyswitchKeyInfo> info,
|
||||
const LweSecretKey &inputKey,
|
||||
const LweSecretKey &outputKey, CSPRNG &csprng);
|
||||
PackingKeyswitchKey() = delete;
|
||||
PackingKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::PackingKeyswitchKeyInfo> info)
|
||||
: buffer(buffer), info(info){};
|
||||
|
||||
static PackingKeyswitchKey
|
||||
fromProto(const Message<concreteprotocol::PackingKeyswitchKey> &proto);
|
||||
|
||||
Message<concreteprotocol::PackingKeyswitchKey> toProto() const;
|
||||
|
||||
const uint64_t *getRawPtr() const;
|
||||
|
||||
size_t getSize() const;
|
||||
|
||||
const Message<concreteprotocol::PackingKeyswitchKeyInfo> &getInfo() const;
|
||||
|
||||
const std::vector<uint64_t> &getBuffer() const;
|
||||
|
||||
typedef Message<concreteprotocol::PackingKeyswitchKeyInfo> InfoType;
|
||||
|
||||
private:
|
||||
std::shared_ptr<std::vector<uint64_t>> buffer;
|
||||
Message<concreteprotocol::PackingKeyswitchKeyInfo> info;
|
||||
};
|
||||
|
||||
} // namespace keys
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,79 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_KEYSETS_H
|
||||
#define CONCRETELANG_COMMON_KEYSETS_H
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keys.h"
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <stdlib.h>
|
||||
#include <string>
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::error::StringError;
|
||||
using concretelang::keys::LweBootstrapKey;
|
||||
using concretelang::keys::LweKeyswitchKey;
|
||||
using concretelang::keys::LweSecretKey;
|
||||
using concretelang::keys::PackingKeyswitchKey;
|
||||
|
||||
namespace concretelang {
|
||||
namespace keysets {
|
||||
|
||||
struct ClientKeyset {
|
||||
std::vector<LweSecretKey> lweSecretKeys;
|
||||
|
||||
static ClientKeyset
|
||||
fromProto(const Message<concreteprotocol::ClientKeyset> &proto);
|
||||
|
||||
Message<concreteprotocol::ClientKeyset> toProto() const;
|
||||
};
|
||||
|
||||
struct ServerKeyset {
|
||||
std::vector<LweBootstrapKey> lweBootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> lweKeyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
static ServerKeyset
|
||||
fromProto(const Message<concreteprotocol::ServerKeyset> &proto);
|
||||
|
||||
Message<concreteprotocol::ServerKeyset> toProto() const;
|
||||
};
|
||||
|
||||
struct Keyset {
|
||||
ServerKeyset server;
|
||||
ClientKeyset client;
|
||||
|
||||
/// Generates a fresh keyset from infos.
|
||||
Keyset(const Message<concreteprotocol::KeysetInfo> &info, CSPRNG &csprng);
|
||||
|
||||
Keyset(ServerKeyset server, ClientKeyset client)
|
||||
: server(server), client(client) {}
|
||||
|
||||
static Keyset fromProto(const Message<concreteprotocol::Keyset> &proto);
|
||||
|
||||
Message<concreteprotocol::Keyset> toProto() const;
|
||||
};
|
||||
|
||||
class KeysetCache {
|
||||
std::string backingDirectoryPath;
|
||||
|
||||
public:
|
||||
KeysetCache(std::string backingDirectoryPath);
|
||||
|
||||
Result<Keyset>
|
||||
getKeyset(const Message<concreteprotocol::KeysetInfo> &keysetInfo,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
KeysetCache() = default;
|
||||
};
|
||||
|
||||
} // namespace keysets
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,341 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_PROTOCOL_H
|
||||
#define CONCRETELANG_COMMON_PROTOCOL_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "capnp/common.h"
|
||||
#include "capnp/compat/json.h"
|
||||
#include "capnp/message.h"
|
||||
#include "capnp/serialize-packed.h"
|
||||
#include "capnp/serialize.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "kj/common.h"
|
||||
#include "kj/exception.h"
|
||||
#include "kj/io.h"
|
||||
#include "kj/std/iostream.h"
|
||||
#include "kj/string.h"
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
const uint64_t MAX_SEGMENT_SIZE = capnp::MAX_SEGMENT_WORDS;
|
||||
|
||||
namespace concretelang {
|
||||
namespace protocol {
|
||||
|
||||
/// Arena carrying capnp messages.
|
||||
///
|
||||
/// This type packs a message with an arena used to store the data in a single
|
||||
/// object.
|
||||
///
|
||||
/// Rationale:
|
||||
/// ----------
|
||||
///
|
||||
/// Capnproto is a performance-oriented serialization framework, which
|
||||
/// approaches the problem by constructing a memory representation that is
|
||||
/// already equivalent to the serialized binary representation.
|
||||
///
|
||||
/// To make that possible and as fast as possible, they use an arena-passing
|
||||
/// programming model, which makes serialization fast, but is also pretty
|
||||
/// invasive:
|
||||
/// + The top-level message being constructed must be known in advanced, so as
|
||||
/// to properly initialize the MallocMessageBuilder.
|
||||
/// + The parent message builder must be passed to a function creating a child
|
||||
/// message, so as to properly initialize the field, and fill it.
|
||||
/// + The arena must be managed at the top level to ensure that the messages are
|
||||
/// always pointing to valid memory locations.
|
||||
///
|
||||
/// In the compiler, we use the concrete-protocol messages for slightly more
|
||||
/// than just serialization, and having a less-than-optimal serialization speed
|
||||
/// is not a big problem for now. For this reason, it makes sense to make the
|
||||
/// use of the capnp messages a little more ergonomic, which is what this type
|
||||
/// allows.
|
||||
template <typename MessageType> struct Message {
|
||||
|
||||
Message() : message(nullptr) {
|
||||
regionBuilder = new capnp::MallocMessageBuilder();
|
||||
message = regionBuilder->initRoot<MessageType>();
|
||||
}
|
||||
|
||||
Message(const typename MessageType::Reader &reader) : message(nullptr) {
|
||||
regionBuilder = new capnp::MallocMessageBuilder(
|
||||
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
|
||||
capnp::AllocationStrategy::FIXED_SIZE);
|
||||
regionBuilder->setRoot(reader);
|
||||
message = regionBuilder->getRoot<MessageType>();
|
||||
}
|
||||
|
||||
Message(const Message &input) : message(nullptr) {
|
||||
regionBuilder = new capnp::MallocMessageBuilder(
|
||||
std::min(input.message.asReader().totalSize().wordCount,
|
||||
MAX_SEGMENT_SIZE),
|
||||
capnp::AllocationStrategy::FIXED_SIZE);
|
||||
regionBuilder->setRoot(input.message.asReader());
|
||||
message = regionBuilder->getRoot<MessageType>();
|
||||
}
|
||||
|
||||
Message &operator=(const typename MessageType::Reader &reader) {
|
||||
if (regionBuilder) {
|
||||
delete regionBuilder;
|
||||
}
|
||||
regionBuilder = new capnp::MallocMessageBuilder(
|
||||
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
|
||||
capnp::AllocationStrategy::FIXED_SIZE);
|
||||
regionBuilder->setRoot(reader);
|
||||
message = regionBuilder->getRoot<MessageType>();
|
||||
return *this;
|
||||
}
|
||||
|
||||
Message &operator=(const Message &input) {
|
||||
if (this != &input) {
|
||||
if (regionBuilder) {
|
||||
delete regionBuilder;
|
||||
}
|
||||
regionBuilder = new capnp::MallocMessageBuilder(
|
||||
std::min(input.message.asReader().totalSize().wordCount,
|
||||
MAX_SEGMENT_SIZE),
|
||||
capnp::AllocationStrategy::FIXED_SIZE);
|
||||
regionBuilder->setRoot(input.message.asReader());
|
||||
message = regionBuilder->getRoot<MessageType>();
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
Message(Message &&input) : message(nullptr) {
|
||||
regionBuilder = input.regionBuilder;
|
||||
message = input.message;
|
||||
input.regionBuilder = nullptr;
|
||||
}
|
||||
|
||||
Message &operator=(Message &&input) {
|
||||
if (this != &input) {
|
||||
if (regionBuilder) {
|
||||
delete regionBuilder;
|
||||
}
|
||||
regionBuilder = input.regionBuilder;
|
||||
message = input.message;
|
||||
input.regionBuilder = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
~Message() {
|
||||
if (regionBuilder) {
|
||||
delete regionBuilder;
|
||||
}
|
||||
}
|
||||
|
||||
typename MessageType::Reader asReader() const { return message.asReader(); }
|
||||
|
||||
typename MessageType::Builder asBuilder() { return message; }
|
||||
|
||||
Result<void> writeBinaryToFd(int fd) const {
|
||||
try {
|
||||
capnp::writeMessageToFd(fd, *regionBuilder);
|
||||
return outcome::success();
|
||||
} catch (const kj::Exception &e) {
|
||||
return StringError("Failed to write message to file descriptor: ")
|
||||
<< e.getDescription().cStr();
|
||||
}
|
||||
}
|
||||
|
||||
Result<void> writeBinaryToOstream(std::ostream &ostream) const {
|
||||
try {
|
||||
kj::std::StdOutputStream kjOstream(ostream);
|
||||
capnp::writeMessage(kjOstream, *regionBuilder);
|
||||
} catch (const kj::Exception &e) {
|
||||
return StringError("Failed to write message to ostream: ")
|
||||
<< e.getDescription().cStr();
|
||||
}
|
||||
ostream.flush();
|
||||
if (!ostream.good()) {
|
||||
return StringError(
|
||||
"Failed to write message to ostream. Ended up in bad state.");
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
Result<std::string> writeBinaryToString() const {
|
||||
auto ostream = std::ostringstream();
|
||||
OUTCOME_TRYV(this->writeBinaryToOstream(ostream));
|
||||
return outcome::success(ostream.str());
|
||||
}
|
||||
|
||||
Result<std::string> writeJsonToString() const {
|
||||
try {
|
||||
capnp::JsonCodec json;
|
||||
kj::String output = json.encode(this->message.asReader());
|
||||
return outcome::success(std::string(output.cStr(), output.size()));
|
||||
} catch (const kj::Exception &e) {
|
||||
return outcome::failure(
|
||||
StringError("Failed to write message to json string: ")
|
||||
<< e.getDescription().cStr());
|
||||
}
|
||||
}
|
||||
|
||||
Result<void> readBinaryFromFd(int fd) {
|
||||
try {
|
||||
capnp::readMessageCopyFromFd(fd, *regionBuilder);
|
||||
this->message = regionBuilder->getRoot<MessageType>();
|
||||
return outcome::success();
|
||||
} catch (const kj::Exception &e) {
|
||||
return StringError("Failed to read message from file descriptor: ")
|
||||
<< e.getDescription().cStr();
|
||||
}
|
||||
}
|
||||
|
||||
Result<void>
|
||||
readBinaryFromIstream(std::istream &istream,
|
||||
capnp::ReaderOptions options = capnp::ReaderOptions()) {
|
||||
try {
|
||||
kj::std::StdInputStream kjIstream(istream);
|
||||
capnp::readMessageCopy(kjIstream, *regionBuilder, options);
|
||||
this->message = regionBuilder->getRoot<MessageType>();
|
||||
return outcome::success();
|
||||
} catch (const kj::Exception &e) {
|
||||
return StringError("Failed to read message from istream: ")
|
||||
<< e.getDescription().cStr();
|
||||
}
|
||||
}
|
||||
|
||||
Result<void>
|
||||
readBinaryFromString(const std::string &input,
|
||||
capnp::ReaderOptions options = capnp::ReaderOptions()) {
|
||||
auto istream = std::istringstream(input);
|
||||
return this->readBinaryFromIstream(istream, options);
|
||||
}
|
||||
|
||||
Result<void> readJsonFromString(const std::string &input) {
|
||||
try {
|
||||
capnp::JsonCodec json;
|
||||
kj::StringPtr stringPointer(input.c_str(), input.size());
|
||||
this->message = this->regionBuilder->template initRoot<MessageType>();
|
||||
json.decode(stringPointer, this->message);
|
||||
return outcome::success();
|
||||
} catch (const kj::Exception &e) {
|
||||
return StringError("Failed to read message from json string: ")
|
||||
<< e.getDescription().cStr();
|
||||
}
|
||||
}
|
||||
|
||||
std::string debugString() const { return writeJsonToString().value(); }
|
||||
|
||||
private:
|
||||
capnp::MallocMessageBuilder *regionBuilder;
|
||||
typename MessageType::Builder message;
|
||||
};
|
||||
|
||||
template struct Message<concreteprotocol::ProgramInfo>;
|
||||
template struct Message<concreteprotocol::CircuitEncodingInfo>;
|
||||
template struct Message<concreteprotocol::Value>;
|
||||
template struct Message<concreteprotocol::GateInfo>;
|
||||
|
||||
/// Helper function turning a vector of integers to a payload.
|
||||
template <typename T>
|
||||
Message<concreteprotocol::Payload>
|
||||
vectorToProtoPayload(const std::vector<T> &input) {
|
||||
auto output = Message<concreteprotocol::Payload>();
|
||||
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
|
||||
auto remainingElms = input.size() % elmsPerBlob;
|
||||
auto nbBlobs = (input.size() / elmsPerBlob) + (remainingElms > 0);
|
||||
auto dataBuilder = output.asBuilder().initData(nbBlobs);
|
||||
// Process all but the last blob, which store as much as `Data` allow.
|
||||
if (nbBlobs > 1) {
|
||||
for (size_t blobIndex = 0; blobIndex < nbBlobs - 1; blobIndex++) {
|
||||
auto blobPtr = input.data() + blobIndex * elmsPerBlob;
|
||||
auto blobLen = elmsPerBlob * sizeof(T);
|
||||
dataBuilder.set(
|
||||
blobIndex,
|
||||
capnp::Data::Reader(reinterpret_cast<const unsigned char *>(blobPtr),
|
||||
blobLen));
|
||||
}
|
||||
}
|
||||
|
||||
// Process the last blob which store the remainder.
|
||||
if (nbBlobs > 0) {
|
||||
auto lastBlobIndex = nbBlobs - 1;
|
||||
auto lastBlobPtr = input.data() + lastBlobIndex * elmsPerBlob;
|
||||
auto lastBlobLen = remainingElms * sizeof(T);
|
||||
dataBuilder.set(
|
||||
lastBlobIndex,
|
||||
capnp::Data::Reader(
|
||||
reinterpret_cast<const unsigned char *>(lastBlobPtr), lastBlobLen));
|
||||
}
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
/// Helper function turning a payload to a vector of integers.
|
||||
template <typename T>
|
||||
std::vector<T>
|
||||
protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
|
||||
auto payloadData = input.asReader().getData();
|
||||
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
|
||||
auto totalPayloadSize = 0;
|
||||
for (auto blob : payloadData) {
|
||||
totalPayloadSize += blob.size();
|
||||
}
|
||||
assert(totalPayloadSize % sizeof(T) == 0);
|
||||
auto dataSize = totalPayloadSize / sizeof(T);
|
||||
auto output = std::vector<T>();
|
||||
output.resize(dataSize);
|
||||
for (size_t blobIndex = 0; blobIndex < payloadData.size(); blobIndex++) {
|
||||
auto blobData = payloadData[blobIndex];
|
||||
auto blobPtr = output.data() + blobIndex * elmsPerBlob;
|
||||
std::memcpy(blobPtr, blobData.begin(), blobData.size());
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/// Helper function turning a payload to a shared vector of integers on the
|
||||
/// heap.
|
||||
template <typename T>
|
||||
std::shared_ptr<std::vector<T>>
|
||||
protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
|
||||
auto payloadData = input.asReader().getData();
|
||||
size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
|
||||
size_t totalPayloadSize = 0;
|
||||
for (auto blob : payloadData) {
|
||||
totalPayloadSize += blob.size();
|
||||
}
|
||||
assert(totalPayloadSize % sizeof(T) == 0);
|
||||
size_t dataSize = totalPayloadSize / sizeof(T);
|
||||
auto output = std::make_shared<std::vector<T>>();
|
||||
output->resize(dataSize);
|
||||
for (size_t blobIndex = 0; blobIndex < payloadData.size(); blobIndex++) {
|
||||
auto blobData = payloadData[blobIndex];
|
||||
auto blobPtr = output->data() + blobIndex * elmsPerBlob;
|
||||
std::memcpy(blobPtr, blobData.begin(), blobData.size());
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
/// Helper function turning a protocol `Shape` object into a vector of
|
||||
/// dimensions.
|
||||
std::vector<size_t>
|
||||
protoShapeToDimensions(const Message<concreteprotocol::Shape> &shape);
|
||||
|
||||
/// Helper function turning a protocol `Shape` object into a vector of
|
||||
/// dimensions.
|
||||
Message<concreteprotocol::Shape>
|
||||
dimensionsToProtoShape(const std::vector<size_t> &input);
|
||||
|
||||
template <typename MessageType> size_t hashMessage(Message<MessageType> &mess);
|
||||
|
||||
} // namespace protocol
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,90 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_TRANSFORMERS_H
|
||||
#define CONCRETELANG_COMMON_TRANSFORMERS_H
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include <memory>
|
||||
#include <stdlib.h>
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::keysets::ClientKeyset;
|
||||
using concretelang::values::Tensor;
|
||||
using concretelang::values::TransportValue;
|
||||
using concretelang::values::Value;
|
||||
|
||||
namespace concretelang {
|
||||
namespace transformers {
|
||||
|
||||
/// A type for input transformers, that is, functions running on the client
|
||||
/// side, that prepare a Value to be sent to the server as a TransportValue.
|
||||
typedef std::function<Result<TransportValue>(Value)> InputTransformer;
|
||||
|
||||
/// A type for output transformers, that is, functions running on the client
|
||||
/// side, that process a TransportValue fetched from the server to be used as a
|
||||
/// Value.
|
||||
typedef std::function<Result<Value>(TransportValue)> OutputTransformer;
|
||||
|
||||
/// A type for arguments transformers, that is, functions running on the server
|
||||
/// side, that transform a TransportValue fetched from the client, to be used as
|
||||
/// argument in a circuit call.
|
||||
typedef std::function<Result<Value>(TransportValue)> ArgTransformer;
|
||||
|
||||
/// A type for return transformers, that is, functions running on the server
|
||||
/// side, that transform a value returned from circuit call into a
|
||||
/// TransportValue to be sent to the client.
|
||||
typedef std::function<Result<TransportValue>(Value)> ReturnTransformer;
|
||||
|
||||
/// A factory static class that generates transformers.
|
||||
class TransformerFactory {
|
||||
public:
|
||||
static Result<InputTransformer>
|
||||
getIndexInputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<OutputTransformer>
|
||||
getIndexOutputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<ArgTransformer>
|
||||
getIndexArgTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<ReturnTransformer>
|
||||
getIndexReturnTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<InputTransformer>
|
||||
getPlaintextInputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<OutputTransformer>
|
||||
getPlaintextOutputTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<ArgTransformer>
|
||||
getPlaintextArgTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<ReturnTransformer>
|
||||
getPlaintextReturnTransformer(Message<concreteprotocol::GateInfo> gateInfo);
|
||||
|
||||
static Result<InputTransformer> getLweCiphertextInputTransformer(
|
||||
ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
|
||||
std::shared_ptr<CSPRNG> csprng, bool useSimulation);
|
||||
|
||||
static Result<OutputTransformer> getLweCiphertextOutputTransformer(
|
||||
ClientKeyset keyset, Message<concreteprotocol::GateInfo> gateInfo,
|
||||
bool useSimulation);
|
||||
|
||||
static Result<ArgTransformer>
|
||||
getLweCiphertextArgTransformer(Message<concreteprotocol::GateInfo> gateInfo,
|
||||
bool useSimulation);
|
||||
|
||||
static Result<ReturnTransformer> getLweCiphertextReturnTransformer(
|
||||
Message<concreteprotocol::GateInfo> gateInfo, bool useSimulation);
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,217 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_COMMON_VALUES_H
|
||||
#define CONCRETELANG_COMMON_VALUES_H
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <initializer_list>
|
||||
#include <optional>
|
||||
#include <stdlib.h>
|
||||
#include <variant>
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::error::StringError;
|
||||
using concretelang::protocol::dimensionsToProtoShape;
|
||||
using concretelang::protocol::Message;
|
||||
using concretelang::protocol::protoPayloadToVector;
|
||||
using concretelang::protocol::protoShapeToDimensions;
|
||||
using concretelang::protocol::vectorToProtoPayload;
|
||||
|
||||
namespace concretelang {
|
||||
namespace values {
|
||||
|
||||
/// A type for public (encrypted or not) values, that can be safely transported
|
||||
/// between client and server to for execution.
|
||||
typedef Message<concreteprotocol::Value> TransportValue;
|
||||
|
||||
/// A type for tensor data.
|
||||
template <typename T> struct Tensor {
|
||||
std::vector<T> values;
|
||||
std::vector<size_t> dimensions;
|
||||
|
||||
Tensor<T>() = default;
|
||||
Tensor<T>(std::vector<T> values, std::vector<size_t> dimensions)
|
||||
: values(values), dimensions(dimensions) {}
|
||||
|
||||
/// Creates an tensor with the shape described by the input dimensions, filled
|
||||
/// with zeros.
|
||||
static Tensor<T> fromDimensions(std::vector<size_t> &dimensions) {
|
||||
uint32_t length = 1;
|
||||
for (auto dim : dimensions) {
|
||||
length *= dim;
|
||||
}
|
||||
auto values = std::vector<T>(length);
|
||||
for (auto &val : values) {
|
||||
*val = 0;
|
||||
}
|
||||
return Tensor{values, dimensions};
|
||||
}
|
||||
|
||||
/// Conversion constructor from a scalar value.
|
||||
Tensor<T>(T in) { this->values.push_back(in); }
|
||||
|
||||
/// Constructor from initializer lists of values and dimensions.
|
||||
Tensor<T>(std::initializer_list<T> values,
|
||||
std::initializer_list<size_t> dimensions) {
|
||||
size_t count = 1;
|
||||
for (auto dim : dimensions) {
|
||||
count *= dim;
|
||||
}
|
||||
assert(values.size() == count);
|
||||
for (auto val : values) {
|
||||
this->values.push_back(val);
|
||||
}
|
||||
for (auto dim : dimensions) {
|
||||
this->dimensions.push_back(dim);
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(const Tensor<T> &b) const {
|
||||
return this->values == b.values && this->dimensions == b.dimensions;
|
||||
}
|
||||
|
||||
Tensor<T> operator-(T b) const {
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] -= b;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor<T> operator-(Tensor<T> b) const {
|
||||
assert(this->dimensions == b.dimensions);
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] -= b.values[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor<T> operator+(T b) const {
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] += b;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor<T> operator+(Tensor<T> b) const {
|
||||
assert(this->dimensions == b.dimensions);
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] += b.values[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor<T> operator*(T b) const {
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] *= b;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
Tensor<T> operator*(Tensor<T> b) const {
|
||||
assert(this->dimensions == b.dimensions);
|
||||
Tensor<T> out = *this;
|
||||
for (size_t i = 0; i < out.values.size(); i++) {
|
||||
out.values[i] *= b.values[i];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
T &operator[](int index) { return this->values[index]; }
|
||||
|
||||
template <typename U> explicit operator Tensor<U>() const {
|
||||
Tensor<U> output;
|
||||
output.dimensions = this->dimensions;
|
||||
for (auto v : this->values) {
|
||||
output.values.push_back((U)v);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
bool isScalar() const { return dimensions.empty(); }
|
||||
};
|
||||
|
||||
/// A type for tensor data of varying precisions. Mainly use to manipulate
|
||||
struct Value {
|
||||
friend class ClientCircuit;
|
||||
|
||||
std::variant<Tensor<uint8_t>, Tensor<int8_t>, Tensor<uint16_t>,
|
||||
Tensor<int16_t>, Tensor<uint32_t>, Tensor<int32_t>,
|
||||
Tensor<uint64_t>, Tensor<int64_t>>
|
||||
inner;
|
||||
Value() = default;
|
||||
Value(Tensor<uint8_t> inner) : inner(inner){};
|
||||
Value(Tensor<uint16_t> inner) : inner(inner){};
|
||||
Value(Tensor<uint32_t> inner) : inner(inner){};
|
||||
Value(Tensor<uint64_t> inner) : inner(inner){};
|
||||
Value(Tensor<int8_t> inner) : inner(inner){};
|
||||
Value(Tensor<int16_t> inner) : inner(inner){};
|
||||
Value(Tensor<int32_t> inner) : inner(inner){};
|
||||
Value(Tensor<int64_t> inner) : inner(inner){};
|
||||
|
||||
/// Turns a server value to a client value, without interpreting the kind of
|
||||
/// value.
|
||||
static Value fromRawTransportValue(TransportValue transportVal);
|
||||
|
||||
/// Turns a client value to a raw (without kind info attached) server value.
|
||||
TransportValue intoRawTransportValue() const;
|
||||
|
||||
bool operator==(const Value &b) const;
|
||||
|
||||
uint32_t getIntegerPrecision() const;
|
||||
|
||||
bool isSigned() const;
|
||||
|
||||
Message<concreteprotocol::Payload> intoProtoPayload() const;
|
||||
|
||||
Message<concreteprotocol::Shape> intoProtoShape() const;
|
||||
|
||||
std::vector<size_t> getDimensions() const;
|
||||
|
||||
size_t getLength() const;
|
||||
|
||||
template <typename T> bool hasElementType() const {
|
||||
return std::holds_alternative<Tensor<T>>(inner);
|
||||
}
|
||||
|
||||
template <typename T> std::optional<Tensor<T>> getTensor() const {
|
||||
if (!hasElementType<T>()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
return std::get<Tensor<T>>(inner);
|
||||
}
|
||||
|
||||
template <typename T> Tensor<T> *getTensorPtr() {
|
||||
if (!hasElementType<T>()) {
|
||||
return nullptr;
|
||||
}
|
||||
return &std::get<Tensor<T>>(inner);
|
||||
}
|
||||
|
||||
bool
|
||||
isCompatibleWithShape(const Message<concreteprotocol::Shape> &shape) const;
|
||||
|
||||
bool isScalar() const;
|
||||
|
||||
Value toUnsigned() const;
|
||||
|
||||
Value toSigned() const;
|
||||
};
|
||||
|
||||
size_t getCorrespondingPrecision(size_t originalPrecision);
|
||||
|
||||
} // namespace values
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -59,14 +59,14 @@ def TFHE_PackingKeyswitchKeyAttr: TFHE_Attr<"GLWEPackingKeyswitchKey", "pksk"> {
|
||||
"mlir::concretelang::TFHE::GLWESecretKey":$inputKey,
|
||||
"mlir::concretelang::TFHE::GLWESecretKey":$outputKey,
|
||||
"int" : $outputPolySize,
|
||||
"int" : $inputLweDim,
|
||||
"int" : $innerLweDim,
|
||||
"int" : $glweDim,
|
||||
"int" : $levels,
|
||||
"int" : $baseLog,
|
||||
DefaultValuedParameter<"int", "-1">: $index
|
||||
);
|
||||
|
||||
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $inputLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
|
||||
let assemblyFormat = " (`[` $index^ `]` )? `<` $inputKey `,` $outputKey`,` $outputPolySize`,` $innerLweDim `,` $glweDim `,` $levels `,` $baseLog `>`";
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -6,16 +6,16 @@
|
||||
#ifndef CONCRETELANG_RUNTIME_CONTEXT_H
|
||||
#define CONCRETELANG_RUNTIME_CONTEXT_H
|
||||
|
||||
#include "concrete-cpu.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include <assert.h>
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <pthread.h>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
#include "concrete-cpu.h"
|
||||
using ::concretelang::keysets::ServerKeyset;
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
#include "bootstrap.h"
|
||||
@@ -40,7 +40,7 @@ typedef struct FFT {
|
||||
typedef struct RuntimeContext {
|
||||
|
||||
RuntimeContext() = delete;
|
||||
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
|
||||
RuntimeContext(ServerKeyset serverKeyset);
|
||||
~RuntimeContext() {
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
for (int i = 0; i < num_devices; ++i) {
|
||||
@@ -53,7 +53,7 @@ typedef struct RuntimeContext {
|
||||
};
|
||||
|
||||
const uint64_t *keyswitch_key_buffer(size_t keyId) {
|
||||
return evaluationKeys.getKeyswitchKey(keyId).buffer();
|
||||
return serverKeyset.lweKeyswitchKeys[keyId].getRawPtr();
|
||||
}
|
||||
|
||||
const double *fourier_bootstrap_key_buffer(size_t keyId) {
|
||||
@@ -61,17 +61,15 @@ typedef struct RuntimeContext {
|
||||
}
|
||||
|
||||
const uint64_t *fp_keyswitch_key_buffer(size_t keyId) {
|
||||
return evaluationKeys.getPackingKeyswitchKey(keyId).buffer();
|
||||
return serverKeyset.packingKeyswitchKeys[keyId].getRawPtr();
|
||||
}
|
||||
|
||||
const struct Fft *fft(size_t keyId) { return ffts[keyId].fft; }
|
||||
|
||||
const ::concretelang::clientlib::EvaluationKeys getKeys() const {
|
||||
return evaluationKeys;
|
||||
}
|
||||
const ServerKeyset getKeys() const { return serverKeyset; }
|
||||
|
||||
private:
|
||||
::concretelang::clientlib::EvaluationKeys evaluationKeys;
|
||||
ServerKeyset serverKeyset;
|
||||
std::vector<std::shared_ptr<std::vector<double>>> fourier_bootstrap_keys;
|
||||
std::vector<FFT> ffts;
|
||||
|
||||
@@ -89,15 +87,15 @@ public:
|
||||
return bsk_gpu[gpu_idx];
|
||||
}
|
||||
|
||||
auto bsk = evaluationKeys.getBootstrapKey(0);
|
||||
auto bsk = serverKeyset.lweBootstrapKeys[0];
|
||||
|
||||
size_t bsk_buffer_len = bsk.size();
|
||||
size_t bsk_buffer_len = bsk.getBuffer().size();
|
||||
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
|
||||
|
||||
void *bsk_gpu_tmp =
|
||||
cuda_malloc_async(bsk_gpu_buffer_size, (cudaStream_t *)stream, gpu_idx);
|
||||
cuda_convert_lwe_bootstrap_key_64(
|
||||
bsk_gpu_tmp, const_cast<uint64_t *>(bsk.buffer()),
|
||||
bsk_gpu_tmp, const_cast<uint64_t *>(bsk.getBuffer().data()),
|
||||
(cudaStream_t *)stream, gpu_idx, input_lwe_dim, glwe_dim, level,
|
||||
poly_size);
|
||||
// Synchronization here is not optional as it works with mutex to
|
||||
@@ -118,14 +116,15 @@ public:
|
||||
if (ksk_gpu[gpu_idx] != nullptr) {
|
||||
return ksk_gpu[gpu_idx];
|
||||
}
|
||||
auto ksk = evaluationKeys.getKeyswitchKey(0);
|
||||
auto ksk = serverKeyset.lweKeyswitchKeys[0];
|
||||
|
||||
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.size();
|
||||
size_t ksk_buffer_size = sizeof(uint64_t) * ksk.getBuffer().size();
|
||||
|
||||
void *ksk_gpu_tmp =
|
||||
cuda_malloc_async(ksk_buffer_size, (cudaStream_t *)stream, gpu_idx);
|
||||
|
||||
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, const_cast<uint64_t *>(ksk.buffer()),
|
||||
cuda_memcpy_async_to_gpu(ksk_gpu_tmp,
|
||||
const_cast<uint64_t *>(ksk.getBuffer().data()),
|
||||
ksk_buffer_size, (cudaStream_t *)stream, gpu_idx);
|
||||
// Synchronization here is not optional as it works with mutex to
|
||||
// prevent other GPU streams from reading partially copied keys.
|
||||
|
||||
@@ -29,7 +29,6 @@
|
||||
|
||||
#include <mlir/ExecutionEngine/CRunnerUtils.h>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Runtime/dfr_debug_interface.h"
|
||||
|
||||
@@ -15,27 +15,30 @@
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keys.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
|
||||
using concretelang::keys::LweBootstrapKey;
|
||||
using concretelang::keys::LweKeyswitchKey;
|
||||
using concretelang::keys::LweSecretKey;
|
||||
using concretelang::keys::PackingKeyswitchKey;
|
||||
using concretelang::keysets::ServerKeyset;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace dfr {
|
||||
|
||||
using namespace ::concretelang::clientlib;
|
||||
|
||||
struct RuntimeContextManager;
|
||||
namespace {
|
||||
static void *dl_handle;
|
||||
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
} // namespace
|
||||
|
||||
template <typename LweKeyType, typename KeyParamType> struct KeyWrapper {
|
||||
template <typename LweKeyType> struct KeyWrapper {
|
||||
std::vector<LweKeyType> keys;
|
||||
|
||||
KeyWrapper() {}
|
||||
@@ -51,38 +54,44 @@ template <typename LweKeyType, typename KeyParamType> struct KeyWrapper {
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar << (size_t)keys.size();
|
||||
for (auto k : keys) {
|
||||
auto params = k.parameters();
|
||||
size_t param_size = sizeof(KeyParamType);
|
||||
ar << hpx::serialization::make_array((char *)¶ms, param_size);
|
||||
ar << (size_t)k.size();
|
||||
ar << hpx::serialization::make_array(k.buffer(), k.size());
|
||||
auto info = k.getInfo();
|
||||
auto maybe_info_string = info.writeBinaryToString();
|
||||
assert(maybe_info_string.has_value());
|
||||
auto info_string = maybe_info_string.value();
|
||||
ar << hpx::serialization::make_array(info_string.c_str(),
|
||||
info_string.size());
|
||||
ar << (size_t)k.getBuffer().size();
|
||||
ar << hpx::serialization::make_array(k.getBuffer().data(),
|
||||
k.getBuffer().size());
|
||||
}
|
||||
}
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
size_t num_keys;
|
||||
ar >> num_keys;
|
||||
for (uint i = 0; i < num_keys; ++i) {
|
||||
KeyParamType params;
|
||||
size_t param_size = sizeof(params);
|
||||
ar >> hpx::serialization::make_array((char *)¶ms, param_size);
|
||||
std::string info_string;
|
||||
ar >> info_string;
|
||||
typename LweKeyType::InfoType info;
|
||||
assert(info.readBinaryFromString(info_string).has_value());
|
||||
size_t key_size;
|
||||
ar >> key_size;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
buffer->resize(key_size);
|
||||
ar >> hpx::serialization::make_array(buffer->data(), key_size);
|
||||
keys.push_back(LweKeyType(buffer, params));
|
||||
keys.push_back(LweKeyType(buffer, info));
|
||||
}
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
template <typename LweKeyType, typename KeyParamType>
|
||||
bool operator==(const KeyWrapper<LweKeyType, KeyParamType> &lhs,
|
||||
const KeyWrapper<LweKeyType, KeyParamType> &rhs) {
|
||||
template <typename LweKeyType>
|
||||
bool operator==(const KeyWrapper<LweKeyType> &lhs,
|
||||
|
||||
const KeyWrapper<LweKeyType> &rhs) {
|
||||
if (lhs.keys.size() != rhs.keys.size())
|
||||
return false;
|
||||
for (size_t i = 0; i < lhs.keys.size(); ++i)
|
||||
if (lhs.keys[i].buffer() != rhs.keys[i].buffer())
|
||||
if (lhs.keys[i].getBuffer() != rhs.keys[i].getBuffer())
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
@@ -110,21 +119,21 @@ struct RuntimeContextManager {
|
||||
if (_dfr_is_root_node()) {
|
||||
RuntimeContext *context = (RuntimeContext *)ctx;
|
||||
|
||||
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw(
|
||||
context->getKeys().getKeyswitchKeys());
|
||||
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw(
|
||||
context->getKeys().getBootstrapKeys());
|
||||
KeyWrapper<LweKeyswitchKey> kskw(context->getKeys().lweKeyswitchKeys);
|
||||
KeyWrapper<LweBootstrapKey> bskw(context->getKeys().lweBootstrapKeys);
|
||||
hpx::collectives::broadcast_to("ksk_keystore", kskw);
|
||||
hpx::collectives::broadcast_to("bsk_keystore", bskw);
|
||||
} else {
|
||||
auto kskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam>>("ksk_keystore");
|
||||
auto bskFut = hpx::collectives::broadcast_from<
|
||||
KeyWrapper<LweBootstrapKey, BootstrapKeyParam>>("bsk_keystore");
|
||||
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw = kskFut.get();
|
||||
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw = bskFut.get();
|
||||
auto kskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey>>(
|
||||
"ksk_keystore");
|
||||
auto bskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey>>(
|
||||
"bsk_keystore");
|
||||
KeyWrapper<LweKeyswitchKey> kskw = kskFut.get();
|
||||
KeyWrapper<LweBootstrapKey> bskw = bskFut.get();
|
||||
context = new mlir::concretelang::RuntimeContext(
|
||||
EvaluationKeys(kskw.keys, bskw.keys, {}));
|
||||
ServerKeyset{bskw.keys, kskw.keys, {}});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,13 @@ typedef enum stream_type {
|
||||
TS_STREAM_TYPE_TOPO_TO_X86_LSAP,
|
||||
TS_STREAM_TYPE_X86_TO_X86_LSAP
|
||||
} stream_type;
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
extern "C" {
|
||||
void *stream_emulator_init();
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
#define CONCRETELANG_TESTLIB_DYNAMIC_MODULE_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::ClientParameters;
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class DynamicModule {
|
||||
public:
|
||||
~DynamicModule();
|
||||
|
||||
static outcome::checked<std::shared_ptr<DynamicModule>, StringError>
|
||||
open(std::string outputPath);
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError>
|
||||
loadClientParametersJSON(std::string outputPath);
|
||||
|
||||
outcome::checked<void, StringError> loadSharedLibrary(std::string outputPath);
|
||||
|
||||
private:
|
||||
std::vector<ClientParameters> clientParametersList;
|
||||
void *libraryHandle;
|
||||
|
||||
friend class ServerLambda;
|
||||
};
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
#endif
|
||||
@@ -1,63 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/DynamicModule.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
using concretelang::clientlib::ScalarOrTensorData;
|
||||
|
||||
/// ServerLambda is a utility class that allows to call a function of a
|
||||
/// compilation result.
|
||||
class ServerLambda {
|
||||
|
||||
public:
|
||||
/// Load the symbol `funcName` from the shared lib in the artifacts folder
|
||||
/// located in `outputPath`
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
load(std::string funcName, std::string outputPath);
|
||||
|
||||
/// Load the symbol `funcName` of the dynamic loaded library
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
/// Call the ServerLambda with public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
call(clientlib::PublicArguments &args,
|
||||
std::optional<clientlib::EvaluationKeys> evaluationKeys,
|
||||
bool simulation = false);
|
||||
|
||||
/// \brief Call the loaded function using opaque pointers to both inputs and
|
||||
/// outputs.
|
||||
/// \param args Array containing pointers to inputs first, followed by
|
||||
/// pointers to outputs.
|
||||
/// \return Error if failed, success otherwise.
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
protected:
|
||||
ClientParameters clientParameters;
|
||||
/// holds a pointer to the entrypoint of the shared lib which
|
||||
void (*func)(void *...);
|
||||
/// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,102 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
#define CONCRETELANG_SERVERLIB_SERVER_LAMBDA_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Transformers.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include <cassert>
|
||||
#include <dlfcn.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
using concretelang::keysets::ServerKeyset;
|
||||
using concretelang::transformers::ArgTransformer;
|
||||
using concretelang::transformers::ReturnTransformer;
|
||||
using concretelang::transformers::TransformerFactory;
|
||||
using concretelang::values::Value;
|
||||
|
||||
namespace concretelang {
|
||||
namespace serverlib {
|
||||
|
||||
/// A smart pointer to a dynamic module.
|
||||
class DynamicModule {
|
||||
friend class ServerCircuit;
|
||||
|
||||
public:
|
||||
~DynamicModule();
|
||||
static Result<std::shared_ptr<DynamicModule>>
|
||||
open(const std::string &outputPath);
|
||||
|
||||
private:
|
||||
void *libraryHandle;
|
||||
};
|
||||
|
||||
class ServerCircuit {
|
||||
friend class ServerProgram;
|
||||
|
||||
public:
|
||||
/// Call the circuit with public arguments.
|
||||
Result<std::vector<TransportValue>> call(const ServerKeyset &serverKeyset,
|
||||
std::vector<TransportValue> &args);
|
||||
|
||||
Result<std::vector<TransportValue>>
|
||||
simulate(std::vector<TransportValue> &args);
|
||||
|
||||
/// Returns the name of this circuit.
|
||||
std::string getName();
|
||||
|
||||
private:
|
||||
ServerCircuit() = default;
|
||||
|
||||
static Result<ServerCircuit>
|
||||
fromDynamicModule(const Message<concreteprotocol::CircuitInfo> &circuitInfo,
|
||||
std::shared_ptr<DynamicModule> dynamicModule,
|
||||
bool useSimulation);
|
||||
|
||||
void invoke(const ServerKeyset &serverKeyset);
|
||||
|
||||
Message<concreteprotocol::CircuitInfo> circuitInfo;
|
||||
bool useSimulation;
|
||||
void (*func)(void *...);
|
||||
std::shared_ptr<DynamicModule> dynamicModule;
|
||||
std::vector<ArgTransformer> argTransformers;
|
||||
std::vector<ReturnTransformer> returnTransformers;
|
||||
std::vector<Value> argsBuffer;
|
||||
std::vector<Value> returnsBuffer;
|
||||
std::vector<size_t> argDescriptorSizes;
|
||||
std::vector<size_t> returnDescriptorSizes;
|
||||
size_t argRawSize;
|
||||
size_t returnRawSize;
|
||||
};
|
||||
|
||||
/// ServerProgram contains multiple
|
||||
class ServerProgram {
|
||||
public:
|
||||
/// Loads a server program from a shared lib path essentially.
|
||||
static Result<ServerProgram>
|
||||
load(const Message<concreteprotocol::ProgramInfo> &programInfo,
|
||||
const std::string &outputPath, bool useSimulation);
|
||||
|
||||
Result<ServerCircuit> getServerCircuit(const std::string &circuitName);
|
||||
|
||||
private:
|
||||
ServerProgram() = default;
|
||||
|
||||
std::vector<ServerCircuit> serverCircuits;
|
||||
};
|
||||
|
||||
} // namespace serverlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,31 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
#define CONCRETELANG_SUPPORT_CLIENTPARAMETERS_H_
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::ChunkInfo;
|
||||
using ::concretelang::clientlib::ClientParameters;
|
||||
|
||||
llvm::Expected<ClientParameters>
|
||||
createClientParametersFromTFHE(mlir::ModuleOp module,
|
||||
llvm::StringRef functionName, int bitsOfSecurity,
|
||||
encodings::CircuitEncodings encodings,
|
||||
std::optional<CRTDecomposition> maybeCrt);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -9,9 +9,15 @@
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "boost/outcome.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
|
||||
namespace protocol = concreteprotocol;
|
||||
using concretelang::protocol::Message;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -75,9 +81,8 @@ struct CompilationFeedback {
|
||||
/// @brief memory usage per location
|
||||
std::map<std::string, int64_t> memoryUsagePerLoc;
|
||||
|
||||
/// Fill the sizes from the client parameters.
|
||||
void
|
||||
fillFromClientParameters(::concretelang::clientlib::ClientParameters params);
|
||||
/// Fill the sizes from the program info.
|
||||
void fillFromProgramInfo(const Message<protocol::ProgramInfo> ¶ms);
|
||||
|
||||
/// Load the compilation feedback from a path
|
||||
static outcome::checked<CompilationFeedback, StringError>
|
||||
@@ -85,6 +90,7 @@ struct CompilationFeedback {
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const mlir::concretelang::CompilationFeedback &);
|
||||
|
||||
bool fromJSON(const llvm::json::Value,
|
||||
mlir::concretelang::CompilationFeedback &, llvm::json::Path);
|
||||
|
||||
|
||||
@@ -6,15 +6,22 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
|
||||
#define CONCRETELANG_SUPPORT_COMPILER_ENGINE_H
|
||||
|
||||
#include <concretelang/Conversion/Utils/GlobalFHEContext.h>
|
||||
#include <concretelang/Support/ClientParametersGeneration.h>
|
||||
#include <concretelang/Support/Encodings.h>
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <llvm/Support/Error.h>
|
||||
#include <llvm/Support/SourceMgr.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/MLIRContext.h>
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include "capnp/message.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/ProgramInfoGeneration.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
using concretelang::protocol::Message;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -71,7 +78,7 @@ struct CompilationOptions {
|
||||
bool emitGPUOps;
|
||||
std::optional<std::vector<int64_t>> fhelinalgTileSizes;
|
||||
|
||||
std::optional<std::string> clientParametersFuncName;
|
||||
std::optional<std::string> mainFuncName;
|
||||
|
||||
optimizer::Config optimizerConfig;
|
||||
|
||||
@@ -84,7 +91,7 @@ struct CompilationOptions {
|
||||
|
||||
/// When compiling from a dialect lower than FHE, one needs to provide
|
||||
/// encodings info manually to allow the client lib to be generated.
|
||||
std::optional<mlir::concretelang::encodings::CircuitEncodings> encodings;
|
||||
std::optional<Message<concreteprotocol::CircuitEncodingInfo>> encodings;
|
||||
|
||||
CompilationOptions()
|
||||
: v0FHEConstraints(std::nullopt), verifyDiagnostics(false),
|
||||
@@ -92,12 +99,12 @@ struct CompilationOptions {
|
||||
maxBatchSize(std::numeric_limits<int64_t>::max()), emitSDFGOps(false),
|
||||
unrollLoopsWithSDFGConvertibleOps(false), dataflowParallelize(false),
|
||||
optimizeTFHE(true), simulate(false), emitGPUOps(false),
|
||||
clientParametersFuncName(std::nullopt),
|
||||
optimizerConfig(optimizer::DEFAULT_CONFIG), chunkIntegers(false),
|
||||
chunkSize(4), chunkWidth(2), encodings(std::nullopt){};
|
||||
mainFuncName(std::nullopt), optimizerConfig(optimizer::DEFAULT_CONFIG),
|
||||
chunkIntegers(false), chunkSize(4), chunkWidth(2),
|
||||
encodings(std::nullopt){};
|
||||
|
||||
CompilationOptions(std::string funcname) : CompilationOptions() {
|
||||
clientParametersFuncName = funcname;
|
||||
mainFuncName = funcname;
|
||||
}
|
||||
|
||||
/// @brief Constructor for CompilationOptions with default parameters for a
|
||||
@@ -130,7 +137,7 @@ public:
|
||||
: compilationContext(compilationContext) {}
|
||||
|
||||
std::optional<mlir::OwningOpRef<mlir::ModuleOp>> mlirModuleRef;
|
||||
std::optional<mlir::concretelang::ClientParameters> clientParameters;
|
||||
std::optional<Message<concreteprotocol::ProgramInfo>> programInfo;
|
||||
std::optional<CompilationFeedback> feedback;
|
||||
std::unique_ptr<llvm::Module> llvmModule;
|
||||
std::optional<mlir::concretelang::V0FHEContext> fheContext;
|
||||
@@ -142,12 +149,11 @@ public:
|
||||
class Library {
|
||||
std::string outputDirPath;
|
||||
std::vector<std::string> objectsPath;
|
||||
std::vector<mlir::concretelang::ClientParameters> clientParametersList;
|
||||
std::vector<mlir::concretelang::CompilationFeedback>
|
||||
compilationFeedbackList;
|
||||
/// Path to the runtime library. Will be linked to the output library if set
|
||||
std::string runtimeLibraryPath;
|
||||
bool cleanUp;
|
||||
mlir::concretelang::CompilationFeedback compilationFeedback;
|
||||
Message<concreteprotocol::ProgramInfo> programInfo;
|
||||
|
||||
public:
|
||||
/// Create a library instance on which you can add compilation results.
|
||||
@@ -156,26 +162,32 @@ public:
|
||||
Library(std::string outputDirPath, std::string runtimeLibraryPath = "",
|
||||
bool cleanUp = true)
|
||||
: outputDirPath(outputDirPath), runtimeLibraryPath(runtimeLibraryPath),
|
||||
cleanUp(cleanUp) {}
|
||||
/// Add a compilation result to the library
|
||||
llvm::Expected<std::string> addCompilation(CompilationResult &compilation);
|
||||
cleanUp(cleanUp), programInfo() {}
|
||||
/// Sets the compilation result used by the library
|
||||
llvm::Expected<std::string>
|
||||
setCompilationResult(CompilationResult &compilation);
|
||||
/// Emit the library artifacts with the previously added compilation result
|
||||
llvm::Error emitArtifacts(bool sharedLib, bool staticLib,
|
||||
bool clientParameters, bool compilationFeedback,
|
||||
bool cppHeader);
|
||||
bool clientParameters, bool compilationFeedback);
|
||||
/// After a shared library has been emitted, its path is here
|
||||
std::string sharedLibraryPath;
|
||||
/// After a static library has been emitted, its path is here
|
||||
std::string staticLibraryPath;
|
||||
|
||||
/// Returns the program info of the library.
|
||||
Message<concreteprotocol::ProgramInfo> getProgramInfo() const;
|
||||
|
||||
/// Returns the path to the output dir.
|
||||
const std::string &getOutputDirPath() const;
|
||||
|
||||
/// Returns the path of the shared library
|
||||
static std::string getSharedLibraryPath(std::string outputDirPath);
|
||||
|
||||
/// Returns the path of the static library
|
||||
static std::string getStaticLibraryPath(std::string outputDirPath);
|
||||
|
||||
/// Returns the path of the client parameters
|
||||
static std::string getClientParametersPath(std::string outputDirPath);
|
||||
/// Returns the path of the program info
|
||||
static std::string getProgramInfoPath(std::string outputDirPath);
|
||||
|
||||
/// Returns the path of the compilation feedback
|
||||
static std::string getCompilationFeedbackPath(std::string outputDirPath);
|
||||
@@ -194,12 +206,10 @@ public:
|
||||
llvm::Expected<std::string> emitStatic();
|
||||
/// Emit a shared library with the previously added compilation result
|
||||
llvm::Expected<std::string> emitShared();
|
||||
/// Emit a json ClientParameters corresponding to library content
|
||||
llvm::Expected<std::string> emitClientParametersJSON();
|
||||
/// Emit a json ProgramInfo corresponding to library content
|
||||
llvm::Expected<std::string> emitProgramInfoJSON();
|
||||
/// Emit a json CompilationFeedback corresponding to library content
|
||||
llvm::Expected<std::string> emitCompilationFeedbackJSON();
|
||||
/// Emit a client header file for this corresponding to library content
|
||||
llvm::Expected<std::string> emitCppHeader();
|
||||
};
|
||||
|
||||
/// Specification of the exit stage of the compilation pipeline
|
||||
@@ -265,8 +275,7 @@ public:
|
||||
|
||||
CompilerEngine(std::shared_ptr<CompilationContext> compilationContext)
|
||||
: overrideMaxEintPrecision(), overrideMaxMANP(), compilerOptions(),
|
||||
generateClientParameters(
|
||||
compilerOptions.clientParametersFuncName.has_value()),
|
||||
generateProgramInfo(compilerOptions.mainFuncName.has_value()),
|
||||
enablePass([](mlir::Pass *pass) { return true; }),
|
||||
compilationContext(compilationContext) {}
|
||||
|
||||
@@ -290,8 +299,7 @@ public:
|
||||
compile(std::vector<std::string> inputs, std::string outputDirPath,
|
||||
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
|
||||
bool generateStaticLib = true, bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true,
|
||||
bool generateCppHeader = true);
|
||||
bool generateCompilationFeedback = true);
|
||||
|
||||
/// Compile and emit artifact to the given outputDirPath from an LLVM source
|
||||
/// manager.
|
||||
@@ -299,38 +307,38 @@ public:
|
||||
compile(llvm::SourceMgr &sm, std::string outputDirPath,
|
||||
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
|
||||
bool generateStaticLib = true, bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true,
|
||||
bool generateCppHeader = true);
|
||||
bool generateCompilationFeedback = true);
|
||||
|
||||
llvm::Expected<CompilerEngine::Library>
|
||||
compile(mlir::ModuleOp module, std::string outputDirPath,
|
||||
std::string runtimeLibraryPath = "", bool generateSharedLib = true,
|
||||
bool generateStaticLib = true, bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true,
|
||||
bool generateCppHeader = true);
|
||||
bool generateCompilationFeedback = true);
|
||||
|
||||
void setCompilationOptions(CompilationOptions &options) {
|
||||
compilerOptions = options;
|
||||
if (options.v0FHEConstraints.has_value()) {
|
||||
setFHEConstraints(*options.v0FHEConstraints);
|
||||
void setCompilationOptions(CompilationOptions options) {
|
||||
compilerOptions = std::move(options);
|
||||
if (compilerOptions.v0FHEConstraints.has_value()) {
|
||||
setFHEConstraints(*compilerOptions.v0FHEConstraints);
|
||||
}
|
||||
|
||||
if (options.clientParametersFuncName.has_value()) {
|
||||
setGenerateClientParameters(true);
|
||||
if (compilerOptions.mainFuncName.has_value()) {
|
||||
setGenerateProgramInfo(true);
|
||||
}
|
||||
}
|
||||
|
||||
CompilationOptions &getCompilationOptions() { return compilerOptions; }
|
||||
|
||||
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
|
||||
void setMaxEintPrecision(size_t v);
|
||||
void setMaxMANP(size_t v);
|
||||
void setGenerateClientParameters(bool v);
|
||||
void setGenerateProgramInfo(bool v);
|
||||
void setEnablePass(std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
protected:
|
||||
std::optional<size_t> overrideMaxEintPrecision;
|
||||
std::optional<size_t> overrideMaxMANP;
|
||||
CompilationOptions compilerOptions;
|
||||
bool generateClientParameters;
|
||||
bool generateProgramInfo;
|
||||
std::function<bool(mlir::Pass *)> enablePass;
|
||||
|
||||
std::shared_ptr<CompilationContext> compilationContext;
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define CONCRETELANG_SUPPORT_ENCODINGS_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@@ -21,118 +22,28 @@
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "capnp/message.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
using concretelang::protocol::Message;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace encodings {
|
||||
|
||||
/// Represents the encoding of a small (unchunked) `FHE::eint` type.
|
||||
struct EncryptedIntegerScalarEncoding {
|
||||
uint64_t width;
|
||||
bool isSigned;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedIntegerScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedIntegerScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
EncryptedIntegerScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
llvm::Expected<Message<concreteprotocol::CircuitEncodingInfo>>
|
||||
getCircuitEncodings(llvm::StringRef functionName, mlir::ModuleOp module);
|
||||
|
||||
/// Represents the encoding of a big (chunked) `FHE::eint` type.
|
||||
struct EncryptedChunkedIntegerScalarEncoding {
|
||||
uint64_t width;
|
||||
bool isSigned;
|
||||
uint64_t chunkSize;
|
||||
uint64_t chunkWidth;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedChunkedIntegerScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedChunkedIntegerScalarEncoding &);
|
||||
static inline llvm::raw_ostream &
|
||||
operator<<(llvm::raw_ostream &OS, EncryptedChunkedIntegerScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a `FHE::ebool` type.
|
||||
struct EncryptedBoolScalarEncoding {};
|
||||
bool fromJSON(const llvm::json::Value, EncryptedBoolScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const EncryptedBoolScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
EncryptedBoolScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a builtin integer type.
|
||||
struct PlaintextScalarEncoding {
|
||||
uint64_t width;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, PlaintextScalarEncoding &,
|
||||
llvm::json::Path);
|
||||
llvm::json::Value toJSON(const PlaintextScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
PlaintextScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a builtin index type.
|
||||
struct IndexScalarEncoding {};
|
||||
bool fromJSON(const llvm::json::Value, IndexScalarEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const IndexScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
IndexScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a scalar value.
|
||||
using ScalarEncoding = std::variant<
|
||||
EncryptedIntegerScalarEncoding, EncryptedChunkedIntegerScalarEncoding,
|
||||
EncryptedBoolScalarEncoding, PlaintextScalarEncoding, IndexScalarEncoding>;
|
||||
bool fromJSON(const llvm::json::Value, ScalarEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const ScalarEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
ScalarEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of a tensor value.
|
||||
struct TensorEncoding {
|
||||
ScalarEncoding scalarEncoding;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, TensorEncoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const TensorEncoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
TensorEncoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encoding of either an input or output value of a circuit.
|
||||
using Encoding = std::variant<TensorEncoding, ScalarEncoding>;
|
||||
bool fromJSON(const llvm::json::Value, Encoding &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const Encoding &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, Encoding e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
/// Represents the encodings of a circuit.
|
||||
struct CircuitEncodings {
|
||||
std::vector<Encoding> inputEncodings;
|
||||
std::vector<Encoding> outputEncodings;
|
||||
};
|
||||
bool fromJSON(const llvm::json::Value, CircuitEncodings &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const CircuitEncodings &);
|
||||
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
CircuitEncodings e) {
|
||||
return OS << llvm::formatv("{0:2}", toJSON(e));
|
||||
}
|
||||
|
||||
llvm::Expected<CircuitEncodings> getCircuitEncodings(
|
||||
llvm::StringRef functionName, mlir::ModuleOp module,
|
||||
std::optional<::concretelang::clientlib::ChunkInfo> maybeChunkInfo);
|
||||
void setCircuitEncodingModes(
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &info,
|
||||
std::optional<
|
||||
Message<concreteprotocol::IntegerCiphertextEncodingInfo::ChunkedMode>>
|
||||
maybeChunk,
|
||||
std::optional<V0FHEContext> maybeFheContext);
|
||||
|
||||
} // namespace encodings
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_JIT_SUPPORT
|
||||
#define CONCRETELANG_SUPPORT_JIT_SUPPORT
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LambdaSupport.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
/// JitCompilationResult is the result of a Jit compilation, the server JIT
|
||||
/// lambda and the clientParameters.
|
||||
struct JitCompilationResult {
|
||||
std::shared_ptr<concretelang::JITLambda> lambda;
|
||||
clientlib::ClientParameters clientParameters;
|
||||
CompilationFeedback feedback;
|
||||
};
|
||||
|
||||
/// JITSupport is the instantiated LambdaSupport for the Jit Compilation.
|
||||
class JITSupport
|
||||
: public LambdaSupport<std::shared_ptr<concretelang::JITLambda>,
|
||||
JitCompilationResult> {
|
||||
|
||||
public:
|
||||
JITSupport(std::optional<std::string> runtimeLibPath = std::nullopt);
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override;
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(mlir::ModuleOp program,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
|
||||
CompilationOptions options) override;
|
||||
using LambdaSupport::compile;
|
||||
|
||||
llvm::Expected<std::shared_ptr<concretelang::JITLambda>>
|
||||
loadServerLambda(JitCompilationResult &result) override {
|
||||
return result.lambda;
|
||||
}
|
||||
|
||||
llvm::Expected<clientlib::ClientParameters>
|
||||
loadClientParameters(JitCompilationResult &result) override {
|
||||
return result.clientParameters;
|
||||
}
|
||||
|
||||
llvm::Expected<CompilationFeedback>
|
||||
loadCompilationFeedback(JitCompilationResult &result) override {
|
||||
return result.feedback;
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(std::shared_ptr<concretelang::JITLambda> lambda,
|
||||
clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) override {
|
||||
return lambda->call(args, evaluationKeys);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compileWithEngine(T program, CompilationOptions options,
|
||||
concretelang::CompilerEngine &engine);
|
||||
|
||||
std::optional<std::string> runtimeLibPath;
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,66 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef COMPILER_JIT_H
|
||||
#define COMPILER_JIT_H
|
||||
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#include <concretelang/ClientLib/KeySet.h>
|
||||
#include <concretelang/ClientLib/PublicArguments.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
using ::concretelang::clientlib::CircuitGate;
|
||||
using ::concretelang::clientlib::KeySet;
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
|
||||
/// of the module.
|
||||
class JITLambda {
|
||||
public:
|
||||
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
|
||||
: type(type), name(name){};
|
||||
|
||||
/// create a JITLambda that point to the function name of the given module.
|
||||
/// Use runtimeLibPath as a shared library if specified.
|
||||
static llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline,
|
||||
std::optional<std::string> runtimeLibPath = {});
|
||||
|
||||
/// Call the JIT lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
call(clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys);
|
||||
|
||||
void setUseDataflow(bool option) { this->useDataflow = option; }
|
||||
|
||||
/// invokeRaw execute the jit lambda with a list of Argument, the last one is
|
||||
/// used to store the result of the computation.
|
||||
/// Example:
|
||||
/// uin64_t arg0 = 1;
|
||||
/// uin64_t res;
|
||||
/// llvm::SmallVector<void *> args{&arg1, &res};
|
||||
/// lambda.invokeRaw(args);
|
||||
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
|
||||
|
||||
private:
|
||||
mlir::LLVM::LLVMFunctionType type;
|
||||
std::string name;
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
/// Tell if the DF parallelization was on or during compilation. This will be
|
||||
/// useful to abort execution if the runtime doesn't support dataflow
|
||||
/// execution, instead of having undefined symbol issues
|
||||
bool useDataflow = false;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif // COMPILER_JIT_H
|
||||
@@ -1,304 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
#define CONCRETELANG_SUPPORT_LAMBDA_ARGUMENT_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/Support/Casting.h>
|
||||
#include <llvm/Support/ExtensibleRTTI.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
/// Abstract base class for lambda arguments
|
||||
class LambdaArgument
|
||||
: public llvm::RTTIExtends<LambdaArgument, llvm::RTTIRoot> {
|
||||
public:
|
||||
LambdaArgument(LambdaArgument &) = delete;
|
||||
|
||||
template <typename T> bool isa() const { return llvm::isa<T>(*this); }
|
||||
|
||||
/// Cast functions on constant instances
|
||||
template <typename T> const T &cast() const { return llvm::cast<T>(*this); }
|
||||
template <typename T> const T *dyn_cast() const {
|
||||
return llvm::dyn_cast<T>(this);
|
||||
}
|
||||
|
||||
/// Cast functions for mutable instances
|
||||
template <typename T> T &cast() { return llvm::cast<T>(*this); }
|
||||
template <typename T> T *dyn_cast() { return llvm::dyn_cast<T>(this); }
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
LambdaArgument(){};
|
||||
};
|
||||
|
||||
/// Class for integer arguments. `BackingIntType` is used as the data
|
||||
/// type to hold the argument's value. The precision is the actual
|
||||
/// precision of the value, which might be different from the precision
|
||||
/// of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class IntLambdaArgument
|
||||
: public llvm::RTTIExtends<IntLambdaArgument<BackingIntType>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef BackingIntType value_type;
|
||||
|
||||
IntLambdaArgument(BackingIntType value,
|
||||
unsigned int precision = 8 * sizeof(BackingIntType))
|
||||
: precision(precision) {
|
||||
if (precision < 8 * sizeof(BackingIntType)) {
|
||||
this->value = value & (1 << (this->precision - 1));
|
||||
} else {
|
||||
this->value = value;
|
||||
}
|
||||
}
|
||||
|
||||
unsigned int getPrecision() const { return this->precision; }
|
||||
BackingIntType getValue() const { return this->value; }
|
||||
|
||||
template <typename OtherBackingIntType>
|
||||
bool operator==(const IntLambdaArgument<OtherBackingIntType> &other) const {
|
||||
return getValue() == other.getValue();
|
||||
}
|
||||
|
||||
template <typename OtherBackingIntType>
|
||||
bool operator!=(const IntLambdaArgument<OtherBackingIntType> &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
unsigned int precision;
|
||||
BackingIntType value;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char IntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
/// Class for encrypted integer arguments. `BackingIntType` is used as
|
||||
/// the data type to hold the argument's plaintext value. The precision
|
||||
/// is the actual precision of the value, which might be different from
|
||||
/// the precision of the backing integer type.
|
||||
template <typename BackingIntType = uint64_t>
|
||||
class EIntLambdaArgument
|
||||
: public llvm::RTTIExtends<EIntLambdaArgument<BackingIntType>,
|
||||
IntLambdaArgument<BackingIntType>> {
|
||||
public:
|
||||
static char ID;
|
||||
};
|
||||
|
||||
template <typename BackingIntType>
|
||||
char EIntLambdaArgument<BackingIntType>::ID = 0;
|
||||
|
||||
namespace {
|
||||
/// Calculates `accu *= factor` or returns an error if the result
|
||||
/// would overflow
|
||||
template <typename AccuT, typename ValT>
|
||||
llvm::Error safeUnsignedMul(AccuT &accu, ValT factor) {
|
||||
static_assert(std::numeric_limits<AccuT>::is_integer &&
|
||||
std::numeric_limits<ValT>::is_integer &&
|
||||
!std::numeric_limits<AccuT>::is_signed &&
|
||||
!std::numeric_limits<ValT>::is_signed,
|
||||
"Only unsigned integers are supported");
|
||||
|
||||
const AccuT left = std::numeric_limits<AccuT>::max() / accu;
|
||||
|
||||
if (left > factor) {
|
||||
accu *= factor;
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
return StreamStringError("Multiplying value ")
|
||||
<< accu << " with " << factor << " would cause an overflow";
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// Class for Tensor arguments. This can either be plaintext tensors
|
||||
/// (for `ScalarArgumentT = IntLambaArgument<T>`) or tensors
|
||||
/// representing encrypted integers (for `ScalarArgumentT =
|
||||
/// EIntLambaArgument<T>`).
|
||||
template <typename ScalarArgumentT>
|
||||
class TensorLambdaArgument
|
||||
: public llvm::RTTIExtends<TensorLambdaArgument<ScalarArgumentT>,
|
||||
LambdaArgument> {
|
||||
public:
|
||||
typedef ScalarArgumentT scalar_type;
|
||||
|
||||
/// Construct tensor argument from the one-dimensional array `value`,
|
||||
/// but interpreting the array's values as a linearized
|
||||
/// multi-dimensional tensor with the sizes of the dimensions
|
||||
/// specified in `dimensions`.
|
||||
TensorLambdaArgument(
|
||||
llvm::ArrayRef<typename ScalarArgumentT::value_type> value,
|
||||
llvm::ArrayRef<int64_t> dimensions)
|
||||
: dimensions(dimensions.vec()) {
|
||||
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
|
||||
}
|
||||
|
||||
/// Construct tensor argument by moving the values from the
|
||||
/// one-dimensional vector `value`, but interpreting the array's
|
||||
/// values as a linearized multi-dimensional tensor with the sizes
|
||||
/// of the dimensions specified in `dimensions`.
|
||||
TensorLambdaArgument(
|
||||
std::vector<typename ScalarArgumentT::value_type> &&value,
|
||||
llvm::ArrayRef<int64_t> dimensions)
|
||||
: dimensions(dimensions.vec()), value(std::move(value)) {}
|
||||
|
||||
/// Construct a one-dimensional tensor argument from the
|
||||
/// array `value`.
|
||||
TensorLambdaArgument(
|
||||
llvm::ArrayRef<typename ScalarArgumentT::value_type> value)
|
||||
: TensorLambdaArgument(value, {(int64_t)value.size()}) {}
|
||||
|
||||
template <std::size_t size1, std::size_t size2>
|
||||
TensorLambdaArgument(
|
||||
typename ScalarArgumentT::value_type (&a)[size1][size2]) {
|
||||
dimensions = {size1, size2};
|
||||
auto value = llvm::MutableArrayRef<typename ScalarArgumentT::value_type>(
|
||||
(typename ScalarArgumentT::value_type *)a, size1 * size2);
|
||||
std::copy(value.begin(), value.end(), std::back_inserter(this->value));
|
||||
}
|
||||
|
||||
const std::vector<int64_t> &getDimensions() const { return this->dimensions; }
|
||||
|
||||
/// Returns the total number of elements in the tensor. If the number
|
||||
/// of elements cannot be represented as a `size_t`, the method
|
||||
/// returns an error.
|
||||
llvm::Expected<size_t> getNumElements() const {
|
||||
size_t accu = 1;
|
||||
|
||||
for (unsigned int dimSize : dimensions)
|
||||
if (llvm::Error err = safeUnsignedMul(accu, dimSize))
|
||||
return std::move(err);
|
||||
|
||||
return accu;
|
||||
}
|
||||
|
||||
/// Returns a bare pointer to the linearized values of the tensor
|
||||
/// (constant version).
|
||||
const typename ScalarArgumentT::value_type *getValue() const {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
/// Returns a bare pointer to the linearized values of the tensor (mutable
|
||||
/// version).
|
||||
typename ScalarArgumentT::value_type *getValue() {
|
||||
return this->value.data();
|
||||
}
|
||||
|
||||
template <typename OtherScalarArgumentT>
|
||||
bool
|
||||
operator==(const TensorLambdaArgument<OtherScalarArgumentT> &other) const {
|
||||
if (getDimensions() != other.getDimensions())
|
||||
return false;
|
||||
|
||||
for (auto pair : llvm::zip(value, other.value)) {
|
||||
if (std::get<0>(pair) != std::get<1>(pair))
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename OtherScalarArgumentT>
|
||||
bool
|
||||
operator!=(const TensorLambdaArgument<OtherScalarArgumentT> &other) const {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
static char ID;
|
||||
|
||||
protected:
|
||||
std::vector<typename ScalarArgumentT::value_type> value;
|
||||
std::vector<int64_t> dimensions;
|
||||
};
|
||||
|
||||
template <typename ScalarArgumentT>
|
||||
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
|
||||
|
||||
namespace {
|
||||
template <typename T> struct NameOfFundamentalType {
|
||||
static const char *get();
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint8_t> {
|
||||
static const char *get() { return "uint8_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int8_t> {
|
||||
static const char *get() { return "int8_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint16_t> {
|
||||
static const char *get() { return "uint16_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int16_t> {
|
||||
static const char *get() { return "int16_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint32_t> {
|
||||
static const char *get() { return "uint32_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int32_t> {
|
||||
static const char *get() { return "int32_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint64_t> {
|
||||
static const char *get() { return "uint64_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int64_t> {
|
||||
static const char *get() { return "int64_t"; }
|
||||
};
|
||||
|
||||
template <typename... Ts> struct LambdaArgumentTypeName;
|
||||
|
||||
template <> struct LambdaArgumentTypeName<> {
|
||||
static const char *get(const mlir::concretelang::LambdaArgument &arg) {
|
||||
assert(false && "No name implemented for this lambda argument type");
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts> struct LambdaArgumentTypeName<T, Ts...> {
|
||||
static const std::string get(const mlir::concretelang::LambdaArgument &arg) {
|
||||
if (arg.dyn_cast<const IntLambdaArgument<T>>()) {
|
||||
return NameOfFundamentalType<T>::get();
|
||||
} else if (arg.dyn_cast<const EIntLambdaArgument<T>>()) {
|
||||
return std::string("encrypted ") + NameOfFundamentalType<T>::get();
|
||||
} else if (arg.dyn_cast<
|
||||
const TensorLambdaArgument<IntLambdaArgument<T>>>()) {
|
||||
return std::string("tensor<") + NameOfFundamentalType<T>::get() + ">";
|
||||
} else if (arg.dyn_cast<
|
||||
const TensorLambdaArgument<EIntLambdaArgument<T>>>()) {
|
||||
return std::string("tensor<encrypted ") +
|
||||
NameOfFundamentalType<T>::get() + ">";
|
||||
}
|
||||
|
||||
return LambdaArgumentTypeName<Ts...>::get(arg);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static inline const std::string
|
||||
getLambdaArgumentTypeAsString(const LambdaArgument &arg) {
|
||||
return LambdaArgumentTypeName<int8_t, uint8_t, int16_t, uint16_t, int32_t,
|
||||
uint32_t, int64_t, uint64_t>::get(arg);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,541 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT
|
||||
#define CONCRETELANG_SUPPORT_LAMBDASUPPORT
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/Support/LambdaArgument.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
// type template specialization
|
||||
|
||||
/// Helper function for implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result);
|
||||
|
||||
template <typename T>
|
||||
inline llvm::Expected<T> typedScalarResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextScalar<T>(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedResult cannot get clear text scalar")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
|
||||
return clearResult.value();
|
||||
}
|
||||
|
||||
/// Specializations of `typedResult()` for scalar results, forwarding
|
||||
/// scalar value to caller.
|
||||
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<uint64_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<int64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<int64_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<uint32_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<uint32_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<int32_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<int32_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<uint16_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<uint16_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<int16_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<int16_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<uint8_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<uint8_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<int8_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
return typedScalarResult<int8_t>(keySet, result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector<T>(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedVectorResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
return std::move(clearResult.value());
|
||||
}
|
||||
|
||||
/// Specializations of `typedResult()` for vector results, initializing
|
||||
/// an `std::vector` of the right size with the results and forwarding
|
||||
/// it to the caller with move semantics.
|
||||
/// Cannot factor out into a template template <typename T> inline
|
||||
/// llvm::Expected<std::vector<uint8_t>>
|
||||
/// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
|
||||
/// to ambiguity with scalar template
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint8_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint8_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<int8_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<int8_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint16_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint16_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<int16_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<int16_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint32_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint32_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<int32_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<int32_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint64_t>(keySet, result);
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<int64_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<int64_t>(keySet, result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildTensorLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(keySet, result);
|
||||
if (auto err = tensorOrError.takeError())
|
||||
return std::move(err);
|
||||
|
||||
auto tensorDim = result.asClearTextShape(0);
|
||||
if (tensorDim.has_error())
|
||||
return StreamStringError(tensorDim.error().mesg);
|
||||
|
||||
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, tensorDim.value());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildScalarLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<T> scalarOrError = typedResult<T>(keySet, result);
|
||||
if (auto err = scalarOrError.takeError())
|
||||
return std::move(err);
|
||||
|
||||
return std::make_unique<IntLambdaArgument<T>>(*scalarOrError);
|
||||
}
|
||||
|
||||
/// pecialization of `typedResult()` for a single result wrapped into
|
||||
/// a `LambdaArgument`.
|
||||
template <>
|
||||
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto gate = keySet.outputGate(0);
|
||||
auto width = gate.shape.width;
|
||||
bool sign = gate.shape.sign;
|
||||
|
||||
if (width > 64)
|
||||
return StreamStringError("Cannot handle values with more than 64 bits");
|
||||
|
||||
// By convention, decrypted integers are always 64 bits wide
|
||||
if (gate.isEncrypted())
|
||||
width = 64;
|
||||
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
// scalar case
|
||||
if (width > 32) {
|
||||
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint64_t>(keySet, result);
|
||||
} else if (width > 16) {
|
||||
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint32_t>(keySet, result);
|
||||
} else if (width > 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint16_t>(keySet, result);
|
||||
} else if (width <= 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
} else if (gate.chunkInfo.has_value()) {
|
||||
// chunked scalar case
|
||||
assert(gate.shape.dimensions.size() == 1);
|
||||
width = gate.shape.size * gate.chunkInfo->width;
|
||||
if (width > 32) {
|
||||
return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint64_t>(keySet, result);
|
||||
} else if (width > 16) {
|
||||
return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint32_t>(keySet, result);
|
||||
} else if (width > 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint16_t>(keySet, result);
|
||||
} else if (width <= 8) {
|
||||
return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
|
||||
: buildScalarLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
} else {
|
||||
// tensor case
|
||||
if (width > 32) {
|
||||
return (sign) ? buildTensorLambdaResult<int64_t>(keySet, result)
|
||||
: buildTensorLambdaResult<uint64_t>(keySet, result);
|
||||
} else if (width > 16) {
|
||||
return (sign) ? buildTensorLambdaResult<int32_t>(keySet, result)
|
||||
: buildTensorLambdaResult<uint32_t>(keySet, result);
|
||||
} else if (width > 8) {
|
||||
return (sign) ? buildTensorLambdaResult<int16_t>(keySet, result)
|
||||
: buildTensorLambdaResult<uint16_t>(keySet, result);
|
||||
} else if (width <= 8) {
|
||||
return (sign) ? buildTensorLambdaResult<int8_t>(keySet, result)
|
||||
: buildTensorLambdaResult<uint8_t>(keySet, result);
|
||||
}
|
||||
}
|
||||
|
||||
assert(false && "Cannot happen");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// Adaptor class that push arguments specified as instances of
|
||||
/// `LambdaArgument` to `clientlib::EncryptedArguments`.
|
||||
class LambdaArgumentAdaptor {
|
||||
public:
|
||||
/// Checks if the argument `arg` is an plaintext / encrypted integer
|
||||
/// argument or a plaintext / encrypted tensor argument with a
|
||||
/// backing integer type `IntT` and push the argument to `encryptedArgs`.
|
||||
///
|
||||
/// Returns `true` if `arg` has one of the types above and its value
|
||||
/// was successfully added to `encryptedArgs`, `false` if none of the types
|
||||
/// matches or an error if a type matched, but adding the argument to
|
||||
/// `encryptedArgs` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
auto res =
|
||||
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
tryAddArg<IntT>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
|
||||
/// error if either the argument type is unsupported or if the argument types
|
||||
/// is supported, but adding it to `encryptedArgs` failed.
|
||||
static inline llvm::Error
|
||||
addArgument(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
// Try the supported integer types; size_t needs explicit
|
||||
// treatment, since it may alias none of the fixed size integer
|
||||
// types
|
||||
llvm::Expected<bool> successOrError =
|
||||
LambdaArgumentAdaptor::tryAddArg<int64_t, int32_t, int16_t, int8_t,
|
||||
uint64_t, uint32_t, uint16_t, uint8_t,
|
||||
size_t>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return StreamStringError("Unknown argument type");
|
||||
else
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
/// Encrypts and build public arguments from lambda arguments
|
||||
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
|
||||
exportArguments(llvm::ArrayRef<const LambdaArgument *> args,
|
||||
clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet) {
|
||||
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
for (auto arg : args) {
|
||||
if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg,
|
||||
keySet)) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
auto check = encryptedArgs->checkAllArgs(keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
auto publicArguments =
|
||||
encryptedArgs->exportPublicArguments(clientParameters);
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
return std::move(publicArguments.value());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lambda, typename CompilationResult> class LambdaSupport {
|
||||
public:
|
||||
typedef Lambda lambda;
|
||||
typedef CompilationResult compilationResult;
|
||||
|
||||
virtual ~LambdaSupport() {}
|
||||
|
||||
/// Compile the mlir program and produces a compilation result if succeed.
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
|
||||
llvm::SourceMgr &program,
|
||||
CompilationOptions options = CompilationOptions("main")) = 0;
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
|
||||
mlir::ModuleOp program,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
|
||||
CompilationOptions options = CompilationOptions("main")) = 0;
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(llvm::StringRef program,
|
||||
CompilationOptions options = CompilationOptions("main")) {
|
||||
return compile(llvm::MemoryBuffer::getMemBuffer(program), options);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> program,
|
||||
CompilationOptions options = CompilationOptions("main")) {
|
||||
llvm::SourceMgr sm;
|
||||
sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
|
||||
return compile(sm, options);
|
||||
}
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<Lambda> virtual loadServerLambda(
|
||||
CompilationResult &result) = 0;
|
||||
|
||||
/// Load the client parameters from the compilation result.
|
||||
llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
|
||||
CompilationResult &result) = 0;
|
||||
|
||||
/// Load the compilation feedback from the compilation result.
|
||||
llvm::Expected<CompilationFeedback> virtual loadCompilationFeedback(
|
||||
CompilationResult &result) = 0;
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
|
||||
Lambda lambda, clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) = 0;
|
||||
|
||||
/// Build the client KeySet from the client parameters.
|
||||
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
|
||||
keySet(clientlib::ClientParameters clientParameters,
|
||||
std::optional<clientlib::KeySetCache> cache, uint64_t seed_msb = 0,
|
||||
uint64_t seed_lsb = 0) {
|
||||
std::shared_ptr<clientlib::KeySetCache> cachePtr;
|
||||
if (cache.has_value()) {
|
||||
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.value());
|
||||
}
|
||||
auto keySet = clientlib::KeySetCache::generate(cachePtr, clientParameters,
|
||||
seed_msb, seed_lsb);
|
||||
if (keySet.has_error()) {
|
||||
return StreamStringError(keySet.error().mesg);
|
||||
}
|
||||
return std::move(keySet.value());
|
||||
}
|
||||
|
||||
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
|
||||
exportArguments(clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<const LambdaArgument *> args) {
|
||||
return LambdaArgumentAdaptor::exportArguments(args, clientParameters,
|
||||
keySet);
|
||||
}
|
||||
|
||||
template <typename ResT>
|
||||
static llvm::Expected<ResT> call(Lambda lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
// Call the lambda
|
||||
auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
|
||||
lambda, publicArguments, evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Decrypt the result
|
||||
return typedResult<ResT>(keySet, **publicResult);
|
||||
}
|
||||
};
|
||||
|
||||
template <class LambdaSupport> class ClientServer {
|
||||
public:
|
||||
static llvm::Expected<ClientServer>
|
||||
create(llvm::StringRef program,
|
||||
CompilationOptions options = CompilationOptions("main"),
|
||||
std::optional<clientlib::KeySetCache> cache = {},
|
||||
LambdaSupport support = LambdaSupport()) {
|
||||
auto compilationResult = support.compile(program, options);
|
||||
if (auto err = compilationResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto lambda = support.loadServerLambda(**compilationResult);
|
||||
if (auto err = lambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto clientParameters = support.loadClientParameters(**compilationResult);
|
||||
if (auto err = clientParameters.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto keySet = support.keySet(*clientParameters, cache);
|
||||
if (auto err = keySet.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
auto f = ClientServer();
|
||||
f.lambda = *lambda;
|
||||
f.compilationResult = std::move(*compilationResult);
|
||||
f.keySet = std::move(*keySet);
|
||||
f.clientParameters = *clientParameters;
|
||||
f.support = support;
|
||||
return std::move(f);
|
||||
}
|
||||
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
|
||||
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
|
||||
args, clientParameters, *this->keySet);
|
||||
|
||||
if (auto err = publicArguments.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
auto evaluationKeys = this->keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, **publicArguments, evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::create(
|
||||
/*simulation*/ false, *keySet, args);
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments =
|
||||
encryptedArgs.value()->exportPublicArguments(clientParameters);
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
template <typename ResT = uint64_t, typename... Args>
|
||||
llvm::Expected<ResT> operator()(const Args... args) {
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::create(
|
||||
/*simulation*/ false, *keySet, args...);
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
auto publicArguments =
|
||||
encryptedArgs.value()->exportPublicArguments(clientParameters);
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult =
|
||||
support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
return typedResult<ResT>(*keySet, **publicResult);
|
||||
}
|
||||
|
||||
private:
|
||||
typename LambdaSupport::lambda lambda;
|
||||
std::unique_ptr<typename LambdaSupport::compilationResult> compilationResult;
|
||||
std::unique_ptr<clientlib::KeySet> keySet;
|
||||
clientlib::ClientParameters clientParameters;
|
||||
LambdaSupport support;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,193 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_LIBRARY_SUPPORT
|
||||
#define CONCRETELANG_SUPPORT_LIBRARY_SUPPORT
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LambdaSupport.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
namespace serverlib = ::concretelang::serverlib;
|
||||
|
||||
/// LibraryCompilationResult is the result of a compilation to a library.
|
||||
struct LibraryCompilationResult {
|
||||
/// The output directory path where the compilation artifacts have been
|
||||
/// generated.
|
||||
std::string outputDirPath;
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
class LibrarySupport
|
||||
: public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
|
||||
|
||||
public:
|
||||
LibrarySupport(std::string outputPath, std::string runtimeLibraryPath = "",
|
||||
bool generateSharedLib = true, bool generateStaticLib = true,
|
||||
bool generateClientParameters = true,
|
||||
bool generateCompilationFeedback = true,
|
||||
bool generateCppHeader = true)
|
||||
: outputPath(outputPath), runtimeLibraryPath(runtimeLibraryPath),
|
||||
generateSharedLib(generateSharedLib),
|
||||
generateStaticLib(generateStaticLib),
|
||||
generateClientParameters(generateClientParameters),
|
||||
generateCompilationFeedback(generateCompilationFeedback),
|
||||
generateCppHeader(generateCppHeader) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, CompilationOptions options) override {
|
||||
// Setup the compiler engine
|
||||
auto context = CompilationContext::createShared();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
engine.setCompilationOptions(options);
|
||||
return compileWithEngine<llvm::SourceMgr &>(program, options, engine);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(mlir::ModuleOp program,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx,
|
||||
CompilationOptions options) override {
|
||||
// Setup the compiler engine
|
||||
concretelang::CompilerEngine engine(cctx);
|
||||
engine.setCompilationOptions(options);
|
||||
return compileWithEngine<mlir::ModuleOp>(program, options, engine);
|
||||
}
|
||||
using LambdaSupport::compile;
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<serverlib::ServerLambda>
|
||||
loadServerLambda(LibraryCompilationResult &result) override {
|
||||
auto lambda =
|
||||
serverlib::ServerLambda::load(result.funcName, result.outputDirPath);
|
||||
if (lambda.has_error()) {
|
||||
return StreamStringError(lambda.error().mesg);
|
||||
}
|
||||
return lambda.value();
|
||||
}
|
||||
|
||||
/// Load the client parameters from the compilation result.
|
||||
llvm::Expected<clientlib::ClientParameters>
|
||||
loadClientParameters(LibraryCompilationResult &result) override {
|
||||
auto path =
|
||||
CompilerEngine::Library::getClientParametersPath(result.outputDirPath);
|
||||
auto params = ClientParameters::load(path);
|
||||
if (params.has_error()) {
|
||||
return StreamStringError(params.error().mesg);
|
||||
}
|
||||
auto param = llvm::find_if(params.value(), [&](ClientParameters param) {
|
||||
return param.functionName == result.funcName;
|
||||
});
|
||||
if (param == params.value().end()) {
|
||||
return StreamStringError("ClientLambda: cannot find function(")
|
||||
<< result.funcName << ") in client parameters path(" << path
|
||||
<< ")";
|
||||
}
|
||||
return *param;
|
||||
}
|
||||
|
||||
std::string getFuncName() {
|
||||
auto path = CompilerEngine::Library::getClientParametersPath(outputPath);
|
||||
auto params = ClientParameters::load(path);
|
||||
if (params.has_error() || params.value().empty()) {
|
||||
return "";
|
||||
}
|
||||
return params.value().front().functionName;
|
||||
}
|
||||
|
||||
/// Load the the compilation result if circuit already compiled
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
loadCompilationResult() {
|
||||
auto funcName = getFuncName();
|
||||
if (funcName.empty()) {
|
||||
return StreamStringError("couldn't find function name");
|
||||
}
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = funcName;
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
llvm::Expected<CompilationFeedback>
|
||||
loadCompilationFeedback(LibraryCompilationResult &result) override {
|
||||
auto path = CompilerEngine::Library::getCompilationFeedbackPath(
|
||||
result.outputDirPath);
|
||||
auto feedback = CompilationFeedback::load(path);
|
||||
if (feedback.has_error()) {
|
||||
return StreamStringError(feedback.error().mesg);
|
||||
}
|
||||
return feedback.value();
|
||||
}
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(serverlib::ServerLambda lambda, clientlib::PublicArguments &args,
|
||||
clientlib::EvaluationKeys &evaluationKeys) override {
|
||||
return lambda.call(args, evaluationKeys);
|
||||
}
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
simulate(serverlib::ServerLambda lambda, clientlib::PublicArguments &args) {
|
||||
return lambda.call(args, {}, /*simulation*/ true);
|
||||
}
|
||||
|
||||
/// Get path to shared library
|
||||
std::string getSharedLibPath() {
|
||||
return CompilerEngine::Library::getSharedLibraryPath(outputPath);
|
||||
}
|
||||
|
||||
/// Get path to client parameters file
|
||||
std::string getClientParametersPath() {
|
||||
return CompilerEngine::Library::getClientParametersPath(outputPath);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string outputPath;
|
||||
std::string runtimeLibraryPath;
|
||||
/// Flags to select generated artifacts
|
||||
bool generateSharedLib;
|
||||
bool generateStaticLib;
|
||||
bool generateClientParameters;
|
||||
bool generateCompilationFeedback;
|
||||
bool generateCppHeader;
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compileWithEngine(T program, CompilationOptions options,
|
||||
concretelang::CompilerEngine &engine) {
|
||||
// Compile to a library
|
||||
auto library = engine.compile(
|
||||
program, outputPath, runtimeLibraryPath, generateSharedLib,
|
||||
generateStaticLib, generateClientParameters,
|
||||
generateCompilationFeedback, generateCppHeader);
|
||||
if (auto err = library.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!options.clientParametersFuncName.has_value()) {
|
||||
return StreamStringError("Need to have a funcname to compile library");
|
||||
}
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->outputDirPath = outputPath;
|
||||
result->funcName = *options.clientParametersFuncName;
|
||||
return std::move(result);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,12 +6,11 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_PIPELINE_H_
|
||||
#define CONCRETELANG_SUPPORT_PIPELINE_H_
|
||||
|
||||
#include <llvm/IR/Module.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
#include <mlir/Transforms/Passes.h>
|
||||
|
||||
#include <concretelang/Support/V0Parameters.h>
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
@@ -109,8 +108,7 @@ mlir::LogicalResult extractSDFGOps(mlir::MLIRContext &context,
|
||||
|
||||
mlir::LogicalResult
|
||||
addRuntimeContext(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool simulation);
|
||||
std::function<bool(mlir::Pass *)> enablePass);
|
||||
|
||||
mlir::LogicalResult
|
||||
lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_PROGRAMINFOGENERATION_H_
|
||||
#define CONCRETELANG_SUPPORT_PROGRAMINFOGENERATION_H_
|
||||
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Support/Encodings.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include <memory>
|
||||
|
||||
using concretelang::protocol::Message;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
llvm::Expected<Message<concreteprotocol::ProgramInfo>>
|
||||
createProgramInfoFromTfheDialect(
|
||||
mlir::ModuleOp module, llvm::StringRef functionName, int bitsOfSecurity,
|
||||
Message<concreteprotocol::CircuitEncodingInfo> &encodings);
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,14 +6,10 @@
|
||||
#ifndef CONCRETELANG_SUPPORT_UTILS_H_
|
||||
#define CONCRETELANG_SUPPORT_UTILS_H_
|
||||
|
||||
#include <concretelang/ClientLib/ClientLambda.h>
|
||||
#include <concretelang/ClientLib/KeySet.h>
|
||||
#include <concretelang/ClientLib/PublicArguments.h>
|
||||
#include <concretelang/ClientLib/Serializers.h>
|
||||
#include <concretelang/Runtime/context.h>
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
namespace concretelang {
|
||||
|
||||
@@ -28,161 +24,6 @@ std::string makePackedFunctionName(llvm::StringRef name);
|
||||
// and two array of rank size for sizes and strides.
|
||||
uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank);
|
||||
|
||||
template <typename Lambda>
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
std::vector<void *> preparedInputArgs,
|
||||
std::optional<clientlib::EvaluationKeys> evaluationKeys,
|
||||
bool simulation = false) {
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
auto numOutputs = 0;
|
||||
for (auto &output : clientParameters.outputs) {
|
||||
auto shape = clientParameters.bufferShape(output, simulation);
|
||||
if (shape.size() == 0) {
|
||||
// scalar gate
|
||||
numOutputs += 1;
|
||||
} else {
|
||||
// buffer gate
|
||||
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
|
||||
}
|
||||
}
|
||||
std::vector<uint64_t> outputs(numOutputs);
|
||||
|
||||
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
|
||||
// inputs and outputs.
|
||||
std::vector<void *> rawArgs(preparedInputArgs.size() +
|
||||
(simulation ? 0 : 1) /*runtime context*/ +
|
||||
1 /* outputs */
|
||||
);
|
||||
size_t i = 0;
|
||||
// Pointers on inputs
|
||||
for (auto &arg : preparedInputArgs) {
|
||||
rawArgs[i++] = &arg;
|
||||
}
|
||||
|
||||
// Some calls require the runtime context while others don't (in simulation)
|
||||
std::unique_ptr<mlir::concretelang::RuntimeContext> runtimeContext;
|
||||
mlir::concretelang::RuntimeContext *rtCtxPtr;
|
||||
if (!simulation) {
|
||||
assert(evaluationKeys.has_value() &&
|
||||
"evaluation keys are required if not in simulation");
|
||||
runtimeContext = std::make_unique<mlir::concretelang::RuntimeContext>(
|
||||
evaluationKeys.value());
|
||||
// Pointer on runtime context, the rawArgs take pointer on actual value that
|
||||
// is passed to the compiled function.
|
||||
rtCtxPtr = runtimeContext.get();
|
||||
rawArgs[i++] = &rtCtxPtr;
|
||||
}
|
||||
|
||||
// Outputs
|
||||
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
|
||||
|
||||
// Invoke
|
||||
if (auto err = lambda->invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Store the result to the PublicResult
|
||||
std::vector<clientlib::SharedScalarOrTensorData> buffers;
|
||||
{
|
||||
size_t outputOffset = 0;
|
||||
for (auto &output : clientParameters.outputs) {
|
||||
auto shape = clientParameters.bufferShape(output, simulation);
|
||||
if (shape.size() == 0) {
|
||||
if (simulation) {
|
||||
// value is encrypted (simulated)
|
||||
auto value = concretelang::clientlib::ScalarOrTensorData(
|
||||
concretelang::clientlib::ScalarData(
|
||||
outputs[outputOffset++],
|
||||
clientlib::EncryptedScalarElementType,
|
||||
clientlib::EncryptedScalarElementWidth));
|
||||
auto sharedValue =
|
||||
clientlib::SharedScalarOrTensorData(std::move(value));
|
||||
|
||||
buffers.push_back(sharedValue);
|
||||
} else {
|
||||
// plain scalar
|
||||
auto value = concretelang::clientlib::ScalarOrTensorData(
|
||||
concretelang::clientlib::ScalarData(outputs[outputOffset++],
|
||||
output.shape.sign,
|
||||
output.shape.width));
|
||||
auto sharedValue =
|
||||
clientlib::SharedScalarOrTensorData(std::move(value));
|
||||
|
||||
buffers.push_back(sharedValue);
|
||||
}
|
||||
} else {
|
||||
// buffer gate
|
||||
auto rank = shape.size();
|
||||
auto allocated = (uint64_t *)outputs[outputOffset++];
|
||||
auto aligned = (uint64_t *)outputs[outputOffset++];
|
||||
auto offset = (size_t)outputs[outputOffset++];
|
||||
size_t *sizes = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
size_t *strides = (size_t *)&outputs[outputOffset];
|
||||
outputOffset += rank;
|
||||
|
||||
size_t elementWidth = (output.isEncrypted())
|
||||
? clientlib::EncryptedScalarElementWidth
|
||||
: output.shape.width;
|
||||
|
||||
bool sign = (output.isEncrypted()) ? false : output.shape.sign;
|
||||
|
||||
auto value = concretelang::clientlib::ScalarOrTensorData(
|
||||
clientlib::tensorDataFromMemRef(rank, elementWidth, sign, allocated,
|
||||
aligned, offset, sizes, strides));
|
||||
auto sharedValue =
|
||||
clientlib::SharedScalarOrTensorData(std::move(value));
|
||||
|
||||
buffers.push_back(sharedValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return clientlib::PublicResult::fromBuffers(clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
|
||||
template <typename Lambda>
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
|
||||
std::optional<clientlib::EvaluationKeys> evaluationKeys,
|
||||
bool simulation = false) {
|
||||
// Prepare arguments with the right calling convention
|
||||
std::vector<void *> preparedArgs;
|
||||
for (auto &sharedArg : arguments.getArguments()) {
|
||||
clientlib::ScalarOrTensorData &arg = sharedArg.get();
|
||||
if (arg.isScalar()) {
|
||||
auto scalar = arg.getScalar().getValueAsU64();
|
||||
preparedArgs.push_back((void *)scalar);
|
||||
} else {
|
||||
clientlib::TensorData &td = arg.getTensor();
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(td.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : td.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = td.getNumElements();
|
||||
for (size_t size : td.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
return invokeRawOnLambda(lambda, arguments.getClientParameters(),
|
||||
preparedArgs, evaluationKeys, simulation);
|
||||
}
|
||||
|
||||
template <typename V, unsigned int N>
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const llvm::SmallVector<V, N> vect) {
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_TESTCIRCUIT_H
|
||||
#define CONCRETELANG_TESTLIB_TESTCIRCUIT_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/ClientLib/ClientLib.h"
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include "concretelang/ServerLib/ServerLib.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "tests_tools/keySetCache.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
|
||||
using concretelang::clientlib::ClientCircuit;
|
||||
using concretelang::clientlib::ClientProgram;
|
||||
using concretelang::csprng::ConcreteCSPRNG;
|
||||
using concretelang::error::Result;
|
||||
using concretelang::keysets::Keyset;
|
||||
using concretelang::serverlib::ServerCircuit;
|
||||
using concretelang::serverlib::ServerProgram;
|
||||
using concretelang::values::TransportValue;
|
||||
using concretelang::values::Value;
|
||||
|
||||
namespace concretelang {
|
||||
namespace testlib {
|
||||
|
||||
class TestCircuit {
|
||||
|
||||
public:
|
||||
static Result<TestCircuit>
|
||||
create(Keyset keyset, Message<concreteprotocol::ProgramInfo> programInfo,
|
||||
std::string sharedLibPath, uint64_t seedMsb, uint64_t seedLsb,
|
||||
bool useSimulation = false) {
|
||||
OUTCOME_TRY(auto serverProgram,
|
||||
ServerProgram::load(programInfo, sharedLibPath, useSimulation));
|
||||
OUTCOME_TRY(auto serverCircuit,
|
||||
serverProgram.getServerCircuit(
|
||||
programInfo.asReader().getCircuits()[0].getName()));
|
||||
__uint128_t seed = seedMsb;
|
||||
seed <<= 64;
|
||||
seed += seedLsb;
|
||||
std::shared_ptr<CSPRNG> csprng = std::make_shared<ConcreteCSPRNG>(seed);
|
||||
OUTCOME_TRY(auto clientProgram,
|
||||
ClientProgram::create(programInfo, keyset.client, csprng,
|
||||
useSimulation));
|
||||
OUTCOME_TRY(auto clientCircuit,
|
||||
clientProgram.getClientCircuit(
|
||||
programInfo.asReader().getCircuits()[0].getName()));
|
||||
auto artifactFolder = std::filesystem::path(sharedLibPath).parent_path();
|
||||
return TestCircuit(clientCircuit, serverCircuit, useSimulation,
|
||||
artifactFolder, keyset);
|
||||
}
|
||||
|
||||
TestCircuit(ClientCircuit clientCircuit, ServerCircuit serverCircuit,
|
||||
bool useSimulation, Keyset keyset)
|
||||
: clientCircuit(clientCircuit), serverCircuit(serverCircuit),
|
||||
useSimulation(useSimulation), keyset(keyset) {}
|
||||
|
||||
Result<std::vector<Value>> call(std::vector<Value> inputs) {
|
||||
auto preparedArgs = std::vector<TransportValue>();
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
OUTCOME_TRY(auto preparedInput, clientCircuit.prepareInput(inputs[i], i));
|
||||
preparedArgs.push_back(preparedInput);
|
||||
}
|
||||
std::vector<TransportValue> returns;
|
||||
if (useSimulation) {
|
||||
OUTCOME_TRY(returns, serverCircuit.simulate(preparedArgs));
|
||||
} else {
|
||||
OUTCOME_TRY(returns, serverCircuit.call(keyset.server, preparedArgs));
|
||||
}
|
||||
std::vector<Value> processedOutputs(returns.size());
|
||||
for (size_t i = 0; i < processedOutputs.size(); i++) {
|
||||
OUTCOME_TRY(processedOutputs[i],
|
||||
clientCircuit.processOutput(returns[i], i));
|
||||
}
|
||||
return processedOutputs;
|
||||
}
|
||||
|
||||
std::string getArtifactFolder() { return artifactFolder; }
|
||||
|
||||
private:
|
||||
TestCircuit(ClientCircuit clientCircuit, ServerCircuit serverCircuit,
|
||||
bool useSimulation, std::string artifactFolder, Keyset keyset)
|
||||
: clientCircuit(clientCircuit), serverCircuit(serverCircuit),
|
||||
useSimulation(useSimulation), artifactFolder(artifactFolder),
|
||||
keyset(keyset){};
|
||||
ClientCircuit clientCircuit;
|
||||
ServerCircuit serverCircuit;
|
||||
bool useSimulation;
|
||||
std::string artifactFolder;
|
||||
Keyset keyset;
|
||||
};
|
||||
|
||||
TestCircuit load(mlir::concretelang::CompilerEngine::Library compiled) {
|
||||
auto keyset =
|
||||
getTestKeySetCachePtr()
|
||||
->getKeyset(compiled.getProgramInfo().asReader().getKeyset(), 0, 0)
|
||||
.value();
|
||||
return TestCircuit::create(
|
||||
keyset, compiled.getProgramInfo().asReader(),
|
||||
compiled.getSharedLibraryPath(compiled.getOutputDirPath()), 0, 0,
|
||||
false)
|
||||
.value();
|
||||
}
|
||||
|
||||
const std::string FUNCNAME = "main";
|
||||
|
||||
std::string getSystemTempFolderPath() {
|
||||
llvm::SmallString<0> tempPath;
|
||||
llvm::sys::path::system_temp_directory(true, tempPath);
|
||||
return std::string(tempPath);
|
||||
}
|
||||
|
||||
std::string createTempFolderIn(const std::string &rootFolder) {
|
||||
std::srand(std::time(nullptr));
|
||||
auto new_path = [=]() {
|
||||
llvm::SmallString<0> outputPath;
|
||||
llvm::sys::path::append(outputPath, rootFolder);
|
||||
std::string uid = std::to_string(
|
||||
std::hash<std::thread::id>()(std::this_thread::get_id()));
|
||||
uid.append("-");
|
||||
uid.append(std::to_string(std::rand()));
|
||||
llvm::sys::path::append(outputPath, uid);
|
||||
return std::string(outputPath);
|
||||
};
|
||||
|
||||
// Macos sometimes fail to create new directories. We have to retry a few
|
||||
// times.
|
||||
for (size_t i = 0; i < 5; i++) {
|
||||
auto pathString = new_path();
|
||||
auto ec = std::error_code();
|
||||
if (!std::filesystem::create_directory(pathString, ec)) {
|
||||
std::cout << "Failed to create directory ";
|
||||
std::cout << pathString;
|
||||
std::cout << " Reason: ";
|
||||
std::cout << ec.message();
|
||||
std::cout << " Retrying....\n";
|
||||
std::cout.flush();
|
||||
} else {
|
||||
std::cout << "Using artifact folder ";
|
||||
std::cout << pathString;
|
||||
std::cout << "\n";
|
||||
std::cout.flush();
|
||||
return pathString;
|
||||
}
|
||||
}
|
||||
std::cout << "Failed to create temp directory 5 times. Aborting...\n";
|
||||
std::cout.flush();
|
||||
assert(false);
|
||||
}
|
||||
|
||||
void deleteFolder(const std::string &folder) {
|
||||
if (!folder.empty()) {
|
||||
auto ec = std::error_code();
|
||||
if (!std::filesystem::remove_all(folder, ec)) {
|
||||
std::cout << "Failed to delete directory ";
|
||||
std::cout << folder;
|
||||
std::cout << " Reason: ";
|
||||
std::cout << ec.message();
|
||||
std::cout.flush();
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<uint8_t> values_3bits() { return {0, 1, 2, 5, 7}; }
|
||||
std::vector<uint8_t> values_6bits() { return {0, 1, 2, 13, 22, 59, 62, 63}; }
|
||||
std::vector<uint8_t> values_7bits() { return {0, 1, 2, 63, 64, 65, 125, 126}; }
|
||||
|
||||
} // namespace testlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,117 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
|
||||
#define CONCRETELANG_TESTLIB_TEST_TYPED_LAMBDA_H
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace testlib {
|
||||
|
||||
using concretelang::clientlib::ClientLambda;
|
||||
using concretelang::clientlib::ClientParameters;
|
||||
using concretelang::clientlib::KeySet;
|
||||
using concretelang::clientlib::KeySetCache;
|
||||
using concretelang::error::StringError;
|
||||
using concretelang::serverlib::ServerLambda;
|
||||
|
||||
inline void freeStringMemory(std::string &s) {
|
||||
std::string empty;
|
||||
s.swap(empty);
|
||||
}
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TestTypedLambda
|
||||
: public concretelang::clientlib::TypedClientLambda<Result, Args...> {
|
||||
|
||||
template <typename Result_, typename... Args_>
|
||||
using TypedClientLambda =
|
||||
concretelang::clientlib::TypedClientLambda<Result_, Args_...>;
|
||||
|
||||
public:
|
||||
static outcome::checked<TestTypedLambda, StringError>
|
||||
load(std::string funcName, std::string outputLib, uint64_t seed_msb = 0,
|
||||
uint64_t seed_lsb = 0,
|
||||
std::shared_ptr<KeySetCache> unsecure_cache = nullptr) {
|
||||
std::string jsonPath =
|
||||
mlir::concretelang::CompilerEngine::Library::getClientParametersPath(
|
||||
outputLib);
|
||||
OUTCOME_TRY(auto cLambda, ClientLambda::load(funcName, jsonPath));
|
||||
OUTCOME_TRY(auto sLambda, ServerLambda::load(funcName, outputLib));
|
||||
OUTCOME_TRY(std::shared_ptr<KeySet> keySet,
|
||||
KeySetCache::generate(unsecure_cache, cLambda.clientParameters,
|
||||
seed_msb, seed_lsb));
|
||||
return TestTypedLambda(cLambda, sLambda, keySet);
|
||||
}
|
||||
|
||||
TestTypedLambda(ClientLambda &cLambda, ServerLambda &sLambda,
|
||||
std::shared_ptr<KeySet> keySet)
|
||||
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
|
||||
keySet(keySet) {}
|
||||
|
||||
TestTypedLambda(TypedClientLambda<Result, Args...> &cLambda,
|
||||
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet)
|
||||
: TypedClientLambda<Result, Args...>(cLambda), serverLambda(sLambda),
|
||||
keySet(keySet) {}
|
||||
|
||||
outcome::checked<Result, StringError> call(Args... args) {
|
||||
// std::string message;
|
||||
|
||||
// client stream
|
||||
// std::ostringstream clientOuput(std::ios::binary);
|
||||
// client argument encryption
|
||||
OUTCOME_TRY(auto encryptedArgs,
|
||||
clientlib::EncryptedArguments::create(/*simulation*/ false,
|
||||
*keySet, args...));
|
||||
OUTCOME_TRY(auto publicArgument,
|
||||
encryptedArgs->exportPublicArguments(this->clientParameters));
|
||||
// client argument serialization
|
||||
// publicArgument->serialize(clientOuput);
|
||||
// message = clientOuput.str();
|
||||
|
||||
// server stream
|
||||
// std::istringstream serverInput(message, std::ios::binary);
|
||||
// freeStringMemory(message);
|
||||
//
|
||||
// OUTCOME_TRY(auto publicArguments,
|
||||
// clientlib::PublicArguments::unserialize(
|
||||
// this->clientParameters,
|
||||
// serverInput));
|
||||
|
||||
// server function call
|
||||
auto evaluationKeys = keySet->evaluationKeys();
|
||||
auto publicResult = serverLambda.call(*publicArgument, evaluationKeys);
|
||||
if (!publicResult) {
|
||||
return StringError("failed calling function");
|
||||
}
|
||||
|
||||
// client result decryption
|
||||
return this->decryptResult(*keySet, *(publicResult.get()));
|
||||
}
|
||||
|
||||
private:
|
||||
ServerLambda serverLambda;
|
||||
std::shared_ptr<KeySet> keySet;
|
||||
};
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
static TestTypedLambda<Result, Args...> TestTypedLambdaFrom(
|
||||
concretelang::clientlib::TypedClientLambda<Result, Args...> &cLambda,
|
||||
ServerLambda &sLambda, std::shared_ptr<KeySet> keySet) {
|
||||
return TestTypedLambda<Result, Args...>(cLambda, sLambda, keySet);
|
||||
}
|
||||
|
||||
} // namespace testlib
|
||||
} // namespace concretelang
|
||||
|
||||
#endif
|
||||
@@ -1,3 +1,5 @@
|
||||
add_compile_options(-fsized-deallocation)
|
||||
|
||||
add_mlir_library(
|
||||
AnalysisUtils
|
||||
Utils.cpp
|
||||
|
||||
@@ -8,7 +8,14 @@ add_compile_options(-fexceptions)
|
||||
# ######################################################################################################################
|
||||
set(LLVM_OPTIONAL_SOURCES CompilerAPIModule.cpp ConcretelangModule.cpp FHEModule.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGPySupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR ConcretelangSupport)
|
||||
add_mlir_public_c_api_library(
|
||||
CONCRETELANGPySupport
|
||||
CompilerEngine.cpp
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRCAPIIR
|
||||
ConcretelangSupport
|
||||
ConcretelangCommon)
|
||||
|
||||
# ######################################################################################################################
|
||||
# Decalare native Python extension
|
||||
@@ -28,9 +35,6 @@ declare_mlir_python_extension(
|
||||
CompilerAPIModule.cpp
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
MLIRCAPIRegisterEverything
|
||||
CONCRETELANGCAPIFHE
|
||||
CONCRETELANGCAPIFHELINALG
|
||||
CONCRETELANGCAPITRACING
|
||||
CONCRETELANGPySupport)
|
||||
|
||||
# ######################################################################################################################
|
||||
@@ -48,9 +52,6 @@ declare_mlir_python_sources(
|
||||
concrete/compiler/compilation_context.py
|
||||
concrete/compiler/compilation_feedback.py
|
||||
concrete/compiler/compilation_options.py
|
||||
concrete/compiler/jit_compilation_result.py
|
||||
concrete/compiler/jit_support.py
|
||||
concrete/compiler/jit_lambda.py
|
||||
concrete/compiler/key_set_cache.py
|
||||
concrete/compiler/key_set.py
|
||||
concrete/compiler/lambda_argument.py
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Bindings/Python/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/ClientLib.h"
|
||||
#include "concretelang/Common/Compat.h"
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOpsDialect.h.inc"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
#include <llvm/Support/Debug.h>
|
||||
#include <mlir-c/Bindings/Python/Interop.h>
|
||||
@@ -24,7 +27,6 @@
|
||||
#include <string>
|
||||
|
||||
using mlir::concretelang::CompilationOptions;
|
||||
using mlir::concretelang::JITSupport;
|
||||
using mlir::concretelang::LambdaArgument;
|
||||
|
||||
class SignalGuard {
|
||||
@@ -75,7 +77,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](std::string funcname) { return CompilationOptions(funcname); }))
|
||||
.def("set_funcname",
|
||||
[](CompilationOptions &options, std::string funcname) {
|
||||
options.clientParametersFuncName = funcname;
|
||||
options.mainFuncName = funcname;
|
||||
})
|
||||
.def("set_verify_diagnostics",
|
||||
[](CompilationOptions &options, bool b) {
|
||||
@@ -207,11 +209,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
"memory_usage_per_location",
|
||||
&mlir::concretelang::CompilationFeedback::memoryUsagePerLoc);
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
m, "JITCompilationResult");
|
||||
pybind11::class_<mlir::concretelang::JITLambda,
|
||||
std::shared_ptr<mlir::concretelang::JITLambda>>(m,
|
||||
"JITLambda");
|
||||
pybind11::class_<mlir::concretelang::CompilationContext,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext>>(
|
||||
m, "CompilationContext")
|
||||
@@ -224,51 +221,6 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
return pybind11::reinterpret_steal<pybind11::object>(
|
||||
mlirPythonContextToCapsule(wrap(mlirCtx)));
|
||||
});
|
||||
pybind11::class_<JITSupport_Py>(m, "JITSupport")
|
||||
.def(pybind11::init([](std::string runtimeLibPath) {
|
||||
return jit_support(runtimeLibPath);
|
||||
}))
|
||||
.def("compile",
|
||||
[](JITSupport_Py &support, std::string mlir_program,
|
||||
CompilationOptions options) {
|
||||
SignalGuard signalGuard;
|
||||
return jit_compile(support, mlir_program.c_str(), options);
|
||||
})
|
||||
.def("compile",
|
||||
[](JITSupport_Py &support, pybind11::object mlir_module,
|
||||
CompilationOptions options,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
|
||||
SignalGuard signalGuard;
|
||||
return jit_compile_module(
|
||||
support,
|
||||
unwrap(mlirPythonCapsuleToModule(mlir_module.ptr())).clone(),
|
||||
options, cctx);
|
||||
})
|
||||
.def("load_client_parameters",
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_client_parameters(support, result);
|
||||
})
|
||||
.def("load_compilation_feedback",
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_compilation_feedback(support, result);
|
||||
})
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](JITSupport_Py &support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
return jit_load_server_lambda(support, result);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](JITSupport_Py &support, concretelang::JITLambda &lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
SignalGuard signalGuard;
|
||||
return jit_server_call(support, lambda, publicArguments,
|
||||
evaluationKeys);
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::LibraryCompilationResult>(
|
||||
m, "LibraryCompilationResult")
|
||||
@@ -278,7 +230,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
funcname,
|
||||
};
|
||||
}));
|
||||
pybind11::class_<concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
|
||||
pybind11::class_<::concretelang::serverlib::ServerLambda>(m, "LibraryLambda");
|
||||
pybind11::class_<LibrarySupport_Py>(m, "LibrarySupport")
|
||||
.def(pybind11::init(
|
||||
[](std::string outputPath, std::string runtimeLibraryPath,
|
||||
@@ -319,21 +271,24 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def(
|
||||
"load_server_lambda",
|
||||
[](LibrarySupport_Py &support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
return library_load_server_lambda(support, result);
|
||||
mlir::concretelang::LibraryCompilationResult &result,
|
||||
bool useSimulation) {
|
||||
return library_load_server_lambda(support, result, useSimulation);
|
||||
},
|
||||
pybind11::return_value_policy::reference)
|
||||
.def("server_call",
|
||||
[](LibrarySupport_Py &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
[](LibrarySupport_Py &support,
|
||||
::concretelang::serverlib::ServerLambda lambda,
|
||||
::concretelang::clientlib::PublicArguments &publicArguments,
|
||||
::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
SignalGuard signalGuard;
|
||||
return library_server_call(support, lambda, publicArguments,
|
||||
evaluationKeys);
|
||||
})
|
||||
.def("simulate",
|
||||
[](LibrarySupport_Py &support, serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &publicArguments) {
|
||||
[](LibrarySupport_Py &support,
|
||||
::concretelang::serverlib::ServerLambda lambda,
|
||||
::concretelang::clientlib::PublicArguments &publicArguments) {
|
||||
pybind11::gil_scoped_release release;
|
||||
return library_simulate(support, lambda, publicArguments);
|
||||
})
|
||||
@@ -341,8 +296,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
[](LibrarySupport_Py &support) {
|
||||
return library_get_shared_lib_path(support);
|
||||
})
|
||||
.def("get_client_parameters_path", [](LibrarySupport_Py &support) {
|
||||
return library_get_client_parameters_path(support);
|
||||
.def("get_program_info_path", [](LibrarySupport_Py &support) {
|
||||
return library_get_program_info_path(support);
|
||||
});
|
||||
|
||||
class ClientSupport {};
|
||||
@@ -350,120 +305,165 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def(pybind11::init())
|
||||
.def_static(
|
||||
"key_set",
|
||||
[](clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySetCache *cache, uint64_t seedMsb,
|
||||
[](::concretelang::clientlib::ClientParameters clientParameters,
|
||||
::concretelang::clientlib::KeySetCache *cache, uint64_t seedMsb,
|
||||
uint64_t seedLsb) {
|
||||
SignalGuard signalGuard;
|
||||
auto optCache = cache == nullptr
|
||||
? std::nullopt
|
||||
: std::optional<clientlib::KeySetCache>(*cache);
|
||||
auto optCache =
|
||||
cache == nullptr
|
||||
? std::nullopt
|
||||
: std::optional<::concretelang::clientlib::KeySetCache>(
|
||||
*cache);
|
||||
return key_set(clientParameters, optCache, seedMsb, seedLsb);
|
||||
},
|
||||
pybind11::arg().none(false), pybind11::arg().none(true),
|
||||
pybind11::arg("seedMsb") = 0, pybind11::arg("seedLsb") = 0)
|
||||
.def_static("encrypt_arguments",
|
||||
[](clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet,
|
||||
std::vector<lambdaArgument> args) {
|
||||
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
|
||||
for (auto i = 0u; i < args.size(); i++) {
|
||||
argsRef.push_back(args[i].ptr.get());
|
||||
}
|
||||
return encrypt_arguments(clientParameters, keySet, argsRef);
|
||||
})
|
||||
.def_static("decrypt_result", [](clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &publicResult) {
|
||||
return decrypt_result(keySet, publicResult);
|
||||
});
|
||||
pybind11::class_<clientlib::KeySetCache>(m, "KeySetCache")
|
||||
.def_static(
|
||||
"encrypt_arguments",
|
||||
[](::concretelang::clientlib::ClientParameters clientParameters,
|
||||
::concretelang::clientlib::KeySet &keySet,
|
||||
std::vector<lambdaArgument> args) {
|
||||
std::vector<mlir::concretelang::LambdaArgument *> argsRef;
|
||||
for (auto i = 0u; i < args.size(); i++) {
|
||||
argsRef.push_back(args[i].ptr.get());
|
||||
}
|
||||
return encrypt_arguments(clientParameters, keySet, argsRef);
|
||||
})
|
||||
.def_static(
|
||||
"decrypt_result",
|
||||
[](::concretelang::clientlib::ClientParameters clientParameters,
|
||||
::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::PublicResult &publicResult) {
|
||||
return decrypt_result(clientParameters, keySet, publicResult);
|
||||
});
|
||||
pybind11::class_<::concretelang::clientlib::KeySetCache>(m, "KeySetCache")
|
||||
.def(pybind11::init<std::string &>());
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::LweSecretKeyParam>(
|
||||
m, "LweSecretKeyParam")
|
||||
.def_readonly("dimension",
|
||||
&::concretelang::clientlib::LweSecretKeyParam::dimension);
|
||||
.def("dimension", [](::concretelang::clientlib::LweSecretKeyParam &key) {
|
||||
return key.info.asReader().getParams().getLweDimension();
|
||||
});
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::BootstrapKeyParam>(
|
||||
m, "BootstrapKeyParam")
|
||||
.def_readonly(
|
||||
"input_secret_key_id",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::inputSecretKeyID)
|
||||
.def_readonly(
|
||||
"output_secret_key_id",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::level)
|
||||
.def_readonly("base_log",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::baseLog)
|
||||
.def_readonly(
|
||||
"glwe_dimension",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::glweDimension)
|
||||
.def_readonly("variance",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::variance)
|
||||
.def_readonly(
|
||||
"polynomial_size",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::polynomialSize)
|
||||
.def_readonly(
|
||||
"input_lwe_dimension",
|
||||
&::concretelang::clientlib::BootstrapKeyParam::inputLweDimension);
|
||||
.def("input_secret_key_id",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getInputId();
|
||||
})
|
||||
.def("output_secret_key_id",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getOutputId();
|
||||
})
|
||||
.def("level",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getLevelCount();
|
||||
})
|
||||
.def("base_log",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getBaseLog();
|
||||
})
|
||||
.def("glwe_dimension",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getGlweDimension();
|
||||
})
|
||||
.def("variance",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getVariance();
|
||||
})
|
||||
.def("polynomial_size",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getPolynomialSize();
|
||||
})
|
||||
.def("input_lwe_dimension",
|
||||
[](::concretelang::clientlib::BootstrapKeyParam &key) {
|
||||
return key.info.asReader().getParams().getInputLweDimension();
|
||||
});
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::KeyswitchKeyParam>(
|
||||
m, "KeyswitchKeyParam")
|
||||
.def_readonly(
|
||||
"input_secret_key_id",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::inputSecretKeyID)
|
||||
.def_readonly(
|
||||
"output_secret_key_id",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::level)
|
||||
.def_readonly("base_log",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::baseLog)
|
||||
.def_readonly("variance",
|
||||
&::concretelang::clientlib::KeyswitchKeyParam::variance);
|
||||
.def("input_secret_key_id",
|
||||
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getInputId();
|
||||
})
|
||||
.def("output_secret_key_id",
|
||||
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getOutputId();
|
||||
})
|
||||
.def("level",
|
||||
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getLevelCount();
|
||||
})
|
||||
.def("base_log",
|
||||
[](::concretelang::clientlib::KeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getBaseLog();
|
||||
})
|
||||
.def("variance", [](::concretelang::clientlib::KeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getVariance();
|
||||
});
|
||||
|
||||
pybind11::class_<::concretelang::clientlib::PackingKeyswitchKeyParam>(
|
||||
m, "PackingKeyswitchKeyParam")
|
||||
.def_readonly("input_secret_key_id",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
inputSecretKeyID)
|
||||
.def_readonly("output_secret_key_id",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
outputSecretKeyID)
|
||||
.def_readonly("level",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::level)
|
||||
.def_readonly(
|
||||
"base_log",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::baseLog)
|
||||
.def_readonly(
|
||||
"glwe_dimension",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::glweDimension)
|
||||
.def_readonly(
|
||||
"polynomial_size",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::polynomialSize)
|
||||
.def_readonly("input_lwe_dimension",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::
|
||||
inputLweDimension)
|
||||
.def_readonly(
|
||||
"variance",
|
||||
&::concretelang::clientlib::PackingKeyswitchKeyParam::variance);
|
||||
.def("input_secret_key_id",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getInputId();
|
||||
})
|
||||
.def("output_secret_key_id",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getOutputId();
|
||||
})
|
||||
.def("level",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getLevelCount();
|
||||
})
|
||||
.def("base_log",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getBaseLog();
|
||||
})
|
||||
.def("glwe_dimension",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getGlweDimension();
|
||||
})
|
||||
.def("polynomial_size",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getPolynomialSize();
|
||||
})
|
||||
.def("input_lwe_dimension",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getInputLweDimension();
|
||||
})
|
||||
.def("variance",
|
||||
[](::concretelang::clientlib::PackingKeyswitchKeyParam &key) {
|
||||
return key.info.asReader().getParams().getVariance();
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::ClientParameters>(m, "ClientParameters")
|
||||
pybind11::class_<::concretelang::clientlib::ClientParameters>(
|
||||
m, "ClientParameters")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
return clientParametersUnserialize(buffer);
|
||||
})
|
||||
.def("serialize",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return pybind11::bytes(
|
||||
clientParametersSerialize(clientParameters));
|
||||
})
|
||||
.def("output_signs",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
std::vector<bool> result;
|
||||
for (auto output : clientParameters.outputs) {
|
||||
if (output.encryption.has_value()) {
|
||||
result.push_back(output.encryption.value().encoding.isSigned);
|
||||
for (auto output : clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getOutputs()) {
|
||||
if (output.getTypeInfo().hasLweCiphertext() &&
|
||||
output.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.hasInteger()) {
|
||||
result.push_back(output.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.getInteger()
|
||||
.getIsSigned());
|
||||
} else {
|
||||
result.push_back(true);
|
||||
}
|
||||
@@ -471,11 +471,21 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
return result;
|
||||
})
|
||||
.def("input_signs",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
std::vector<bool> result;
|
||||
for (auto input : clientParameters.inputs) {
|
||||
if (input.encryption.has_value()) {
|
||||
result.push_back(input.encryption.value().encoding.isSigned);
|
||||
for (auto input : clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getInputs()) {
|
||||
if (input.getTypeInfo().hasLweCiphertext() &&
|
||||
input.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.hasInteger()) {
|
||||
result.push_back(input.getTypeInfo()
|
||||
.getLweCiphertext()
|
||||
.getEncoding()
|
||||
.getInteger()
|
||||
.getIsSigned());
|
||||
} else {
|
||||
result.push_back(true);
|
||||
}
|
||||
@@ -483,246 +493,244 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
return result;
|
||||
})
|
||||
.def_readonly("secret_keys",
|
||||
&mlir::concretelang::ClientParameters::secretKeys)
|
||||
&::concretelang::clientlib::ClientParameters::secretKeys)
|
||||
.def_readonly("bootstrap_keys",
|
||||
&mlir::concretelang::ClientParameters::bootstrapKeys)
|
||||
&::concretelang::clientlib::ClientParameters::bootstrapKeys)
|
||||
.def_readonly("keyswitch_keys",
|
||||
&mlir::concretelang::ClientParameters::keyswitchKeys)
|
||||
&::concretelang::clientlib::ClientParameters::keyswitchKeys)
|
||||
.def_readonly(
|
||||
"packing_keyswitch_keys",
|
||||
&mlir::concretelang::ClientParameters::packingKeyswitchKeys);
|
||||
&::concretelang::clientlib::ClientParameters::packingKeyswitchKeys);
|
||||
|
||||
pybind11::class_<clientlib::KeySet>(m, "KeySet")
|
||||
pybind11::class_<::concretelang::clientlib::KeySet>(m, "KeySet")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
std::unique_ptr<KeySet> result = keySetUnserialize(buffer);
|
||||
std::unique_ptr<::concretelang::clientlib::KeySet> result =
|
||||
keySetUnserialize(buffer);
|
||||
return result;
|
||||
})
|
||||
.def("serialize",
|
||||
[](clientlib::KeySet &keySet) {
|
||||
[](::concretelang::clientlib::KeySet &keySet) {
|
||||
return pybind11::bytes(keySetSerialize(keySet));
|
||||
})
|
||||
.def("client_parameters",
|
||||
[](clientlib::KeySet &keySet) { return keySet.clientParameters(); })
|
||||
.def("get_evaluation_keys",
|
||||
[](clientlib::KeySet &keySet) { return keySet.evaluationKeys(); });
|
||||
[](::concretelang::clientlib::KeySet &keySet) {
|
||||
return ::concretelang::clientlib::EvaluationKeys{
|
||||
keySet.keyset.server};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::SharedScalarOrTensorData>(m, "Value")
|
||||
pybind11::class_<::concretelang::clientlib::SharedScalarOrTensorData>(m,
|
||||
"Value")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
return valueUnserialize(buffer);
|
||||
})
|
||||
.def("serialize", [](const clientlib::SharedScalarOrTensorData &value) {
|
||||
return pybind11::bytes(valueSerialize(value));
|
||||
});
|
||||
.def(
|
||||
"serialize",
|
||||
[](const ::concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
return pybind11::bytes(valueSerialize(value));
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::ValueExporter>(m, "ValueExporter")
|
||||
.def_static("create",
|
||||
[](clientlib::KeySet &keySet,
|
||||
mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::ValueExporter(keySet, clientParameters);
|
||||
})
|
||||
pybind11::class_<::concretelang::clientlib::ValueExporter>(m, "ValueExporter")
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createValueExporter(keySet, clientParameters);
|
||||
})
|
||||
.def("export_scalar",
|
||||
[](clientlib::ValueExporter &exporter, size_t position,
|
||||
int64_t value) {
|
||||
[](::concretelang::clientlib::ValueExporter &exporter,
|
||||
size_t position, int64_t value) {
|
||||
SignalGuard signalGuard;
|
||||
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError>
|
||||
result = exporter.exportValue(value, position);
|
||||
auto info = exporter.circuit.getCircuitInfo()
|
||||
.asReader()
|
||||
.getInputs()[position];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto result = exporter.circuit.prepareInput(
|
||||
typeTransformer({Tensor<int64_t>(value)}), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(
|
||||
std::move(result.value()));
|
||||
return ::concretelang::clientlib::SharedScalarOrTensorData{
|
||||
result.value()};
|
||||
})
|
||||
.def("export_tensor", [](clientlib::ValueExporter &exporter,
|
||||
.def("export_tensor", [](::concretelang::clientlib::ValueExporter
|
||||
&exporter,
|
||||
size_t position, std::vector<int64_t> values,
|
||||
std::vector<int64_t> shape) {
|
||||
SignalGuard signalGuard;
|
||||
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
|
||||
exporter.exportValue(values.data(), shape, position);
|
||||
std::vector<size_t> dimensions(shape.begin(), shape.end());
|
||||
auto info =
|
||||
exporter.circuit.getCircuitInfo().asReader().getInputs()[position];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto result = exporter.circuit.prepareInput(
|
||||
typeTransformer({Tensor<int64_t>(values, dimensions)}), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
|
||||
return ::concretelang::clientlib::SharedScalarOrTensorData{
|
||||
result.value()};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::SimulatedValueExporter>(m,
|
||||
"SimulatedValueExporter")
|
||||
.def_static("create",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::SimulatedValueExporter(clientParameters);
|
||||
})
|
||||
pybind11::class_<::concretelang::clientlib::SimulatedValueExporter>(
|
||||
m, "SimulatedValueExporter")
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createSimulatedValueExporter(clientParameters);
|
||||
})
|
||||
.def("export_scalar",
|
||||
[](clientlib::SimulatedValueExporter &exporter, size_t position,
|
||||
int64_t value) {
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError>
|
||||
result = exporter.exportValue(value, position);
|
||||
[](::concretelang::clientlib::SimulatedValueExporter &exporter,
|
||||
size_t position, int64_t value) {
|
||||
SignalGuard signalGuard;
|
||||
auto info = exporter.circuit.getCircuitInfo()
|
||||
.asReader()
|
||||
.getInputs()[position];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto result = exporter.circuit.prepareInput(
|
||||
typeTransformer({Tensor<int64_t>(value)}), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(
|
||||
std::move(result.value()));
|
||||
return ::concretelang::clientlib::SharedScalarOrTensorData{
|
||||
result.value()};
|
||||
})
|
||||
.def("export_tensor", [](clientlib::SimulatedValueExporter &exporter,
|
||||
.def("export_tensor", [](::concretelang::clientlib::SimulatedValueExporter
|
||||
&exporter,
|
||||
size_t position, std::vector<int64_t> values,
|
||||
std::vector<int64_t> shape) {
|
||||
outcome::checked<clientlib::ScalarOrTensorData, StringError> result =
|
||||
exporter.exportValue(values.data(), shape, position);
|
||||
SignalGuard signalGuard;
|
||||
std::vector<size_t> dimensions(shape.begin(), shape.end());
|
||||
auto info =
|
||||
exporter.circuit.getCircuitInfo().asReader().getInputs()[position];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto result = exporter.circuit.prepareInput(
|
||||
typeTransformer({Tensor<int64_t>(values, dimensions)}), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return clientlib::SharedScalarOrTensorData(std::move(result.value()));
|
||||
return ::concretelang::clientlib::SharedScalarOrTensorData{
|
||||
result.value()};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::ValueDecrypter>(m, "ValueDecrypter")
|
||||
.def_static("create",
|
||||
[](clientlib::KeySet &keySet,
|
||||
mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::ValueDecrypter(keySet, clientParameters);
|
||||
})
|
||||
.def("get_shape",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position) {
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.getShape(position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_scalar",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
pybind11::class_<::concretelang::clientlib::ValueDecrypter>(m,
|
||||
"ValueDecrypter")
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::KeySet &keySet,
|
||||
::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createValueDecrypter(keySet, clientParameters);
|
||||
})
|
||||
.def("decrypt",
|
||||
[](::concretelang::clientlib::ValueDecrypter &decrypter,
|
||||
size_t position,
|
||||
::concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
SignalGuard signalGuard;
|
||||
|
||||
outcome::checked<int64_t, StringError> result =
|
||||
decrypter.decrypt<int64_t>(value.get(), position);
|
||||
|
||||
auto result =
|
||||
decrypter.circuit.processOutput(value.value, position);
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_tensor",
|
||||
[](clientlib::ValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
SignalGuard signalGuard;
|
||||
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.decryptTensor<int64_t>(value.get(), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
return lambdaArgument{
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(
|
||||
mlir::concretelang::LambdaArgument{result.value()})};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::SimulatedValueDecrypter>(
|
||||
pybind11::class_<::concretelang::clientlib::SimulatedValueDecrypter>(
|
||||
m, "SimulatedValueDecrypter")
|
||||
.def_static("create",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters) {
|
||||
return clientlib::SimulatedValueDecrypter(clientParameters);
|
||||
})
|
||||
.def("get_shape",
|
||||
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position) {
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.getShape(position);
|
||||
.def_static(
|
||||
"create",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
return createSimulatedValueDecrypter(clientParameters);
|
||||
})
|
||||
.def("decrypt",
|
||||
[](::concretelang::clientlib::SimulatedValueDecrypter &decrypter,
|
||||
size_t position,
|
||||
::concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
SignalGuard signalGuard;
|
||||
|
||||
auto result =
|
||||
decrypter.circuit.processOutput(value.value, position);
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_scalar",
|
||||
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
outcome::checked<int64_t, StringError> result =
|
||||
decrypter.decrypt<int64_t>(value.get(), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
})
|
||||
.def("decrypt_tensor",
|
||||
[](clientlib::SimulatedValueDecrypter &decrypter, size_t position,
|
||||
clientlib::SharedScalarOrTensorData &value) {
|
||||
outcome::checked<std::vector<int64_t>, StringError> result =
|
||||
decrypter.decryptTensor<int64_t>(value.get(), position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
}
|
||||
|
||||
return result.value();
|
||||
return lambdaArgument{
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(
|
||||
mlir::concretelang::LambdaArgument{result.value()})};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::PublicArguments,
|
||||
std::unique_ptr<clientlib::PublicArguments>>(
|
||||
pybind11::class_<::concretelang::clientlib::PublicArguments,
|
||||
std::unique_ptr<::concretelang::clientlib::PublicArguments>>(
|
||||
m, "PublicArguments")
|
||||
.def_static(
|
||||
"create",
|
||||
[](const mlir::concretelang::ClientParameters &clientParameters,
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers) {
|
||||
return clientlib::PublicArguments(clientParameters, buffers);
|
||||
[](const ::concretelang::clientlib::ClientParameters
|
||||
&clientParameters,
|
||||
std::vector<::concretelang::clientlib::SharedScalarOrTensorData>
|
||||
&buffers) {
|
||||
std::vector<TransportValue> vals;
|
||||
for (auto buf : buffers) {
|
||||
vals.push_back(buf.value);
|
||||
}
|
||||
return ::concretelang::clientlib::PublicArguments{vals};
|
||||
})
|
||||
.def_static(
|
||||
"deserialize",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const pybind11::bytes &buffer) {
|
||||
return publicArgumentsUnserialize(clientParameters, buffer);
|
||||
})
|
||||
.def_static("deserialize",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters,
|
||||
const pybind11::bytes &buffer) {
|
||||
return publicArgumentsUnserialize(clientParameters, buffer);
|
||||
})
|
||||
.def("serialize", [](clientlib::PublicArguments &publicArgument) {
|
||||
return pybind11::bytes(publicArgumentsSerialize(publicArgument));
|
||||
});
|
||||
pybind11::class_<clientlib::PublicResult>(m, "PublicResult")
|
||||
.def_static("deserialize",
|
||||
[](mlir::concretelang::ClientParameters &clientParameters,
|
||||
const pybind11::bytes &buffer) {
|
||||
return publicResultUnserialize(clientParameters, buffer);
|
||||
})
|
||||
.def("serialize",
|
||||
[](clientlib::PublicResult &publicResult) {
|
||||
[](::concretelang::clientlib::PublicArguments &publicArgument) {
|
||||
return pybind11::bytes(publicArgumentsSerialize(publicArgument));
|
||||
});
|
||||
pybind11::class_<::concretelang::clientlib::PublicResult>(m, "PublicResult")
|
||||
.def_static(
|
||||
"deserialize",
|
||||
[](::concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const pybind11::bytes &buffer) {
|
||||
return publicResultUnserialize(clientParameters, buffer);
|
||||
})
|
||||
.def("serialize",
|
||||
[](::concretelang::clientlib::PublicResult &publicResult) {
|
||||
return pybind11::bytes(publicResultSerialize(publicResult));
|
||||
})
|
||||
.def("n_values",
|
||||
[](const clientlib::PublicResult &publicResult) {
|
||||
return publicResult.buffers.size();
|
||||
[](const ::concretelang::clientlib::PublicResult &publicResult) {
|
||||
return publicResult.values.size();
|
||||
})
|
||||
.def("get_value",
|
||||
[](clientlib::PublicResult &publicResult, size_t position) {
|
||||
outcome::checked<clientlib::SharedScalarOrTensorData, StringError>
|
||||
result = publicResult.getValue(position);
|
||||
|
||||
if (result.has_error()) {
|
||||
throw std::runtime_error(result.error().mesg);
|
||||
[](::concretelang::clientlib::PublicResult &publicResult,
|
||||
size_t position) {
|
||||
if (position >= publicResult.values.size()) {
|
||||
throw std::runtime_error("Failed to get public result value.");
|
||||
}
|
||||
|
||||
return result.value();
|
||||
return ::concretelang::clientlib::SharedScalarOrTensorData{
|
||||
publicResult.values[position]};
|
||||
});
|
||||
|
||||
pybind11::class_<clientlib::EvaluationKeys>(m, "EvaluationKeys")
|
||||
pybind11::class_<::concretelang::clientlib::EvaluationKeys>(m,
|
||||
"EvaluationKeys")
|
||||
.def_static("deserialize",
|
||||
[](const pybind11::bytes &buffer) {
|
||||
return evaluationKeysUnserialize(buffer);
|
||||
})
|
||||
.def("serialize", [](clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
|
||||
});
|
||||
.def("serialize",
|
||||
[](::concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
return pybind11::bytes(evaluationKeysSerialize(evaluationKeys));
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor_u8",
|
||||
|
||||
@@ -4,79 +4,19 @@
|
||||
// for license information.
|
||||
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "capnp/message.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/Bindings/Python/CompilerEngine.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Compat.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
#include "concretelang/Support/Jit.h"
|
||||
|
||||
#define GET_OR_THROW_LLVM_EXPECTED(VARNAME, EXPECTED) \
|
||||
auto VARNAME = EXPECTED; \
|
||||
if (auto err = VARNAME.takeError()) { \
|
||||
throw std::runtime_error(llvm::toString(std::move(err))); \
|
||||
}
|
||||
|
||||
// JIT Support bindings ///////////////////////////////////////////////////////
|
||||
|
||||
MLIR_CAPI_EXPORTED JITSupport_Py jit_support(std::string runtimeLibPath) {
|
||||
auto opt = runtimeLibPath.empty()
|
||||
? std::nullopt
|
||||
: std::optional<std::string>(runtimeLibPath);
|
||||
return JITSupport_Py{mlir::concretelang::JITSupport(opt)};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile(JITSupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::JitCompilationResult>
|
||||
jit_compile_module(
|
||||
JITSupport_Py support, mlir::ModuleOp module,
|
||||
mlir::concretelang::CompilationOptions options,
|
||||
std::shared_ptr<mlir::concretelang::CompilationContext> cctx) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, cctx, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
jit_load_client_parameters(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(clientParameters,
|
||||
support.support.loadClientParameters(result));
|
||||
return *clientParameters;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::CompilationFeedback
|
||||
jit_load_compilation_feedback(
|
||||
JITSupport_Py support, mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationFeedback,
|
||||
support.support.loadCompilationFeedback(result));
|
||||
return *compilationFeedback;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::shared_ptr<mlir::concretelang::JITLambda>
|
||||
jit_load_server_lambda(JITSupport_Py support,
|
||||
mlir::concretelang::JitCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
jit_server_call(JITSupport_Py support, mlir::concretelang::JITLambda &lambda,
|
||||
concretelang::clientlib::PublicArguments &args,
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(publicResult, lambda.call(args, evaluationKeys));
|
||||
return std::move(*publicResult);
|
||||
}
|
||||
|
||||
// Library Support bindings ///////////////////////////////////////////////////
|
||||
MLIR_CAPI_EXPORTED LibrarySupport_Py
|
||||
@@ -86,15 +26,17 @@ library_support(const char *outputPath, const char *runtimeLibraryPath,
|
||||
bool generateCppHeader) {
|
||||
return LibrarySupport_Py{mlir::concretelang::LibrarySupport(
|
||||
outputPath, runtimeLibraryPath, generateSharedLib, generateStaticLib,
|
||||
generateClientParameters, generateCompilationFeedback,
|
||||
generateCppHeader)};
|
||||
generateClientParameters, generateCompilationFeedback)};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<mlir::concretelang::LibraryCompilationResult>
|
||||
library_compile(LibrarySupport_Py support, const char *module,
|
||||
mlir::concretelang::CompilationOptions options) {
|
||||
llvm::SourceMgr sm;
|
||||
sm.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(module),
|
||||
llvm::SMLoc());
|
||||
GET_OR_THROW_LLVM_EXPECTED(compilationResult,
|
||||
support.support.compile(module, options));
|
||||
support.support.compile(sm, options));
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
@@ -108,7 +50,7 @@ library_compile_module(
|
||||
return std::move(*compilationResult);
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
|
||||
library_load_client_parameters(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
@@ -127,11 +69,11 @@ library_load_compilation_feedback(
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::serverlib::ServerLambda
|
||||
library_load_server_lambda(
|
||||
LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(serverLambda,
|
||||
support.support.loadServerLambda(result));
|
||||
library_load_server_lambda(LibrarySupport_Py support,
|
||||
mlir::concretelang::LibraryCompilationResult &result,
|
||||
bool useSimulation) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
serverLambda, support.support.loadServerLambda(result, useSimulation));
|
||||
return *serverLambda;
|
||||
}
|
||||
|
||||
@@ -160,8 +102,8 @@ library_get_shared_lib_path(LibrarySupport_Py support) {
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
library_get_client_parameters_path(LibrarySupport_Py support) {
|
||||
return support.support.getClientParametersPath();
|
||||
library_get_program_info_path(LibrarySupport_Py support) {
|
||||
return support.support.getProgramInfoPath();
|
||||
}
|
||||
|
||||
// Client Support bindings ///////////////////////////////////////////////////
|
||||
@@ -170,157 +112,303 @@ MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
key_set(concretelang::clientlib::ClientParameters clientParameters,
|
||||
std::optional<concretelang::clientlib::KeySetCache> cache,
|
||||
uint64_t seedMsb, uint64_t seedLsb) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
ks, (mlir::concretelang::LambdaSupport<int, int>::keySet(
|
||||
clientParameters, cache, seedMsb, seedLsb)));
|
||||
return std::move(*ks);
|
||||
if (cache.has_value()) {
|
||||
GET_OR_THROW_RESULT(Keyset keyset,
|
||||
(*cache).keysetCache.getKeyset(
|
||||
clientParameters.programInfo.asReader().getKeyset(),
|
||||
seedMsb, seedLsb));
|
||||
concretelang::clientlib::KeySet output{keyset};
|
||||
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
|
||||
} else {
|
||||
__uint128_t seed = seedMsb;
|
||||
seed <<= 64;
|
||||
seed += seedLsb;
|
||||
auto csprng = concretelang::csprng::ConcreteCSPRNG(seed);
|
||||
auto keyset =
|
||||
Keyset(clientParameters.programInfo.asReader().getKeyset(), csprng);
|
||||
concretelang::clientlib::KeySet output{keyset};
|
||||
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
|
||||
}
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
encrypt_arguments(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *> args) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
publicArguments,
|
||||
(mlir::concretelang::LambdaSupport<int, int>::exportArguments(
|
||||
clientParameters, keySet, args)));
|
||||
return std::move(*publicArguments);
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
false);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto circuit = maybeProgram.value()
|
||||
.getClientCircuit(clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getName())
|
||||
.value();
|
||||
std::vector<TransportValue> output;
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
auto info =
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getInputs()[i];
|
||||
auto typeTransformer = getPythonTypeTransformer(info);
|
||||
auto input = typeTransformer(args[i]->value);
|
||||
auto maybePrepared = circuit.prepareInput(input, i);
|
||||
|
||||
if (maybePrepared.has_failure()) {
|
||||
throw std::runtime_error(maybePrepared.as_failure().error().mesg);
|
||||
}
|
||||
output.push_back(maybePrepared.value());
|
||||
}
|
||||
concretelang::clientlib::PublicArguments publicArgs{output};
|
||||
return std::make_unique<concretelang::clientlib::PublicArguments>(
|
||||
std::move(publicArgs));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
decrypt_result(concretelang::clientlib::KeySet &keySet,
|
||||
decrypt_result(concretelang::clientlib::ClientParameters clientParameters,
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::PublicResult &publicResult) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
result, mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
keySet, publicResult));
|
||||
lambdaArgument result_{std::move(*result)};
|
||||
return result_;
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
false);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
if (publicResult.values.size() != 1) {
|
||||
throw std::runtime_error("Tried to decrypt with wrong arity.");
|
||||
}
|
||||
auto circuit = maybeProgram.value()
|
||||
.getClientCircuit(clientParameters.programInfo.asReader()
|
||||
.getCircuits()[0]
|
||||
.getName())
|
||||
.value();
|
||||
auto maybeProcessed = circuit.processOutput(publicResult.values[0], 0);
|
||||
if (maybeProcessed.has_failure()) {
|
||||
throw std::runtime_error(maybeProcessed.as_failure().error().mesg);
|
||||
}
|
||||
|
||||
mlir::concretelang::LambdaArgument out{maybeProcessed.value()};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicArguments>
|
||||
publicArgumentsUnserialize(
|
||||
mlir::concretelang::ClientParameters &clientParameters,
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
|
||||
clientParameters, istream);
|
||||
if (!argsOrError) {
|
||||
throw std::runtime_error(argsOrError.error().mesg);
|
||||
auto publicArgumentsProto = Message<concreteprotocol::PublicArguments>();
|
||||
if (publicArgumentsProto.readBinaryFromString(buffer).has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize public arguments.");
|
||||
}
|
||||
return std::move(argsOrError.value());
|
||||
std::vector<TransportValue> values;
|
||||
for (auto arg : publicArgumentsProto.asReader().getArgs()) {
|
||||
values.push_back(arg);
|
||||
}
|
||||
concretelang::clientlib::PublicArguments output{values};
|
||||
return std::make_unique<concretelang::clientlib::PublicArguments>(
|
||||
std::move(output));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string publicArgumentsSerialize(
|
||||
concretelang::clientlib::PublicArguments &publicArguments) {
|
||||
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
auto voidOrError = publicArguments.serialize(buffer);
|
||||
if (!voidOrError) {
|
||||
throw std::runtime_error(voidOrError.error().mesg);
|
||||
auto publicArgumentsProto = Message<concreteprotocol::PublicArguments>();
|
||||
auto argBuilder =
|
||||
publicArgumentsProto.asBuilder().initArgs(publicArguments.values.size());
|
||||
for (size_t i = 0; i < publicArguments.values.size(); i++) {
|
||||
argBuilder.setWithCaveats(i, publicArguments.values[i].asReader());
|
||||
}
|
||||
return buffer.str();
|
||||
auto maybeBuffer = publicArgumentsProto.writeBinaryToString();
|
||||
if (maybeBuffer.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize public arguments.");
|
||||
}
|
||||
return maybeBuffer.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::PublicResult>
|
||||
publicResultUnserialize(mlir::concretelang::ClientParameters &clientParameters,
|
||||
const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
auto publicResultOrError = concretelang::clientlib::PublicResult::unserialize(
|
||||
clientParameters, istream);
|
||||
if (!publicResultOrError) {
|
||||
throw std::runtime_error(publicResultOrError.error().mesg);
|
||||
publicResultUnserialize(
|
||||
concretelang::clientlib::ClientParameters &clientParameters,
|
||||
const std::string &buffer) {
|
||||
auto publicResultsProto = Message<concreteprotocol::PublicResults>();
|
||||
if (publicResultsProto.readBinaryFromString(buffer).has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize public results.");
|
||||
}
|
||||
return std::move(publicResultOrError.value());
|
||||
std::vector<TransportValue> values;
|
||||
for (auto res : publicResultsProto.asReader().getResults()) {
|
||||
values.push_back(res);
|
||||
}
|
||||
concretelang::clientlib::PublicResult output{values};
|
||||
return std::make_unique<concretelang::clientlib::PublicResult>(
|
||||
std::move(output));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
publicResultSerialize(concretelang::clientlib::PublicResult &publicResult) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
auto voidOrError = publicResult.serialize(buffer);
|
||||
if (!voidOrError) {
|
||||
throw std::runtime_error(voidOrError.error().mesg);
|
||||
std::string buffer;
|
||||
auto publicResultsProto = Message<concreteprotocol::PublicResults>();
|
||||
auto resBuilder =
|
||||
publicResultsProto.asBuilder().initResults(publicResult.values.size());
|
||||
for (size_t i = 0; i < publicResult.values.size(); i++) {
|
||||
resBuilder.setWithCaveats(i, publicResult.values[i].asReader());
|
||||
}
|
||||
return buffer.str();
|
||||
auto maybeBuffer = publicResultsProto.writeBinaryToString();
|
||||
if (maybeBuffer.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize public results.");
|
||||
}
|
||||
return maybeBuffer.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::EvaluationKeys
|
||||
evaluationKeysUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys =
|
||||
concretelang::clientlib::readEvaluationKeys(istream);
|
||||
|
||||
if (istream.fail()) {
|
||||
throw std::runtime_error("Cannot read evaluation keys");
|
||||
auto serverKeysetProto = Message<concreteprotocol::ServerKeyset>();
|
||||
auto maybeError = serverKeysetProto.readBinaryFromString(
|
||||
buffer, capnp::ReaderOptions{7000000000, 64});
|
||||
if (maybeError.has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize server keyset." +
|
||||
maybeError.as_failure().error().mesg);
|
||||
}
|
||||
|
||||
return evaluationKeys;
|
||||
auto serverKeyset =
|
||||
concretelang::keysets::ServerKeyset::fromProto(serverKeysetProto);
|
||||
concretelang::clientlib::EvaluationKeys output{serverKeyset};
|
||||
return output;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string evaluationKeysSerialize(
|
||||
concretelang::clientlib::EvaluationKeys &evaluationKeys) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
concretelang::clientlib::operator<<(buffer, evaluationKeys);
|
||||
return buffer.str();
|
||||
auto serverKeysetProto = evaluationKeys.keyset.toProto();
|
||||
auto maybeBuffer = serverKeysetProto.writeBinaryToString();
|
||||
if (maybeBuffer.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize evaluation keys.");
|
||||
}
|
||||
return maybeBuffer.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::unique_ptr<concretelang::clientlib::KeySet>
|
||||
keySetUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
std::unique_ptr<concretelang::clientlib::KeySet> keySet =
|
||||
concretelang::clientlib::readKeySet(istream);
|
||||
|
||||
if (istream.fail() || keySet.get() == nullptr) {
|
||||
throw std::runtime_error("Cannot read key set");
|
||||
auto keysetProto = Message<concreteprotocol::Keyset>();
|
||||
auto maybeError = keysetProto.readBinaryFromString(
|
||||
buffer, capnp::ReaderOptions{7000000000, 64});
|
||||
if (maybeError.has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize keyset." +
|
||||
maybeError.as_failure().error().mesg);
|
||||
}
|
||||
|
||||
return keySet;
|
||||
auto keyset = concretelang::keysets::Keyset::fromProto(keysetProto);
|
||||
concretelang::clientlib::KeySet output{keyset};
|
||||
return std::make_unique<concretelang::clientlib::KeySet>(std::move(output));
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
keySetSerialize(concretelang::clientlib::KeySet &keySet) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
concretelang::clientlib::operator<<(buffer, keySet);
|
||||
return buffer.str();
|
||||
auto keysetProto = keySet.keyset.toProto();
|
||||
auto maybeBuffer = keysetProto.writeBinaryToString();
|
||||
if (maybeBuffer.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize keys.");
|
||||
}
|
||||
return maybeBuffer.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SharedScalarOrTensorData
|
||||
valueUnserialize(const std::string &buffer) {
|
||||
std::stringstream istream(buffer);
|
||||
|
||||
auto value = concretelang::clientlib::unserializeScalarOrTensorData(istream);
|
||||
if (istream.fail() || value.has_error()) {
|
||||
throw std::runtime_error("Cannot read data");
|
||||
auto inner = TransportValue();
|
||||
if (inner.readBinaryFromString(buffer).has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize Value");
|
||||
}
|
||||
|
||||
return concretelang::clientlib::SharedScalarOrTensorData(
|
||||
std::move(value.value()));
|
||||
return {inner};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
valueSerialize(const concretelang::clientlib::SharedScalarOrTensorData &value) {
|
||||
std::ostringstream buffer(std::ios::binary);
|
||||
serializeScalarOrTensorData(value.get(), buffer);
|
||||
return buffer.str();
|
||||
auto maybeString = value.value.writeBinaryToString();
|
||||
if (maybeString.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize Value");
|
||||
}
|
||||
return maybeString.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED mlir::concretelang::ClientParameters
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueExporter createValueExporter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
false);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
return ::concretelang::clientlib::ValueExporter{maybeCircuit.value()};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueExporter
|
||||
createSimulatedValueExporter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo, ::concretelang::keysets::ClientKeyset(),
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
true);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
return ::concretelang::clientlib::SimulatedValueExporter{
|
||||
maybeCircuit.value()};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ValueDecrypter createValueDecrypter(
|
||||
concretelang::clientlib::KeySet &keySet,
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(), keySet.keyset.client,
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
false);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
return ::concretelang::clientlib::ValueDecrypter{maybeCircuit.value()};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::SimulatedValueDecrypter
|
||||
createSimulatedValueDecrypter(
|
||||
concretelang::clientlib::ClientParameters &clientParameters) {
|
||||
|
||||
auto maybeProgram = ::concretelang::clientlib::ClientProgram::create(
|
||||
clientParameters.programInfo.asReader(),
|
||||
::concretelang::keysets::ClientKeyset(),
|
||||
std::make_shared<CSPRNG>(::concretelang::csprng::ConcreteCSPRNG(0)),
|
||||
true);
|
||||
if (maybeProgram.has_failure()) {
|
||||
throw std::runtime_error(maybeProgram.as_failure().error().mesg);
|
||||
}
|
||||
auto maybeCircuit = maybeProgram.value().getClientCircuit(
|
||||
clientParameters.programInfo.asReader().getCircuits()[0].getName());
|
||||
return ::concretelang::clientlib::SimulatedValueDecrypter{
|
||||
maybeCircuit.value()};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED concretelang::clientlib::ClientParameters
|
||||
clientParametersUnserialize(const std::string &json) {
|
||||
GET_OR_THROW_LLVM_EXPECTED(
|
||||
clientParams,
|
||||
llvm::json::parse<mlir::concretelang::ClientParameters>(json));
|
||||
return clientParams.get();
|
||||
auto programInfo = Message<concreteprotocol::ProgramInfo>();
|
||||
if (programInfo.readJsonFromString(json).has_failure()) {
|
||||
throw std::runtime_error("Failed to deserialize client parameters");
|
||||
}
|
||||
return concretelang::clientlib::ClientParameters{programInfo, {}, {}, {}, {}};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::string
|
||||
clientParametersSerialize(mlir::concretelang::ClientParameters ¶ms) {
|
||||
llvm::json::Value value(params);
|
||||
std::string jsonParams;
|
||||
llvm::raw_string_ostream buffer(jsonParams);
|
||||
buffer << value;
|
||||
return jsonParams;
|
||||
clientParametersSerialize(concretelang::clientlib::ClientParameters ¶ms) {
|
||||
auto maybeJson = params.programInfo.writeJsonToString();
|
||||
if (maybeJson.has_failure()) {
|
||||
throw std::runtime_error("Failed to serialize client parameters");
|
||||
}
|
||||
return maybeJson.value();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED void terminateDataflowParallelization() { _dfr_terminate(); }
|
||||
@@ -350,283 +438,180 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module) {
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
|
||||
return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int8_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int16_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int32_t>>>() ||
|
||||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int64_t>>>();
|
||||
}
|
||||
|
||||
template <typename T, typename R>
|
||||
MLIR_CAPI_EXPORTED std::vector<R> copyTensorLambdaArgumentTo64bitsvector(
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<T>> *tensor) {
|
||||
auto numElements = tensor->getNumElements();
|
||||
if (!numElements) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Couldn't get size of tensor: "
|
||||
<< llvm::toString(std::move(numElements.takeError()));
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
std::vector<R> res;
|
||||
res.reserve(*numElements);
|
||||
T *data = tensor->getValue();
|
||||
for (size_t i = 0; i < *numElements; i++) {
|
||||
res.push_back(data[i]);
|
||||
}
|
||||
return res;
|
||||
return !lambda_arg.ptr->value.isScalar();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::vector<uint64_t>
|
||||
lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
|
||||
if (!sizeOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Couldn't get size of tensor: "
|
||||
<< llvm::toString(sizeOrErr.takeError());
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
|
||||
return data;
|
||||
if (auto tensor = lambda_arg.ptr->value.getTensor<uint8_t>(); tensor) {
|
||||
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint16_t>();
|
||||
tensor) {
|
||||
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint32_t>();
|
||||
tensor) {
|
||||
Tensor<uint64_t> out = (Tensor<uint64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<uint64_t>();
|
||||
tensor) {
|
||||
return tensor.value().values;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<uint8_t, uint64_t>(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<uint16_t, uint64_t>(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<uint32_t, uint64_t>(arg);
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::vector<int64_t>
|
||||
lambdaArgumentGetSignedTensorData(lambdaArgument &lambda_arg) {
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int64_t>>>()) {
|
||||
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
|
||||
if (!sizeOrErr) {
|
||||
std::string backingString;
|
||||
llvm::raw_string_ostream os(backingString);
|
||||
os << "Couldn't get size of tensor: "
|
||||
<< llvm::toString(sizeOrErr.takeError());
|
||||
throw std::runtime_error(os.str());
|
||||
}
|
||||
std::vector<int64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
|
||||
return data;
|
||||
if (auto tensor = lambda_arg.ptr->value.getTensor<int8_t>(); tensor) {
|
||||
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int16_t>(); tensor) {
|
||||
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int32_t>(); tensor) {
|
||||
Tensor<int64_t> out = (Tensor<int64_t>)tensor.value();
|
||||
return out.values;
|
||||
} else if (auto tensor = lambda_arg.ptr->value.getTensor<int64_t>(); tensor) {
|
||||
return tensor.value().values;
|
||||
} else {
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int8_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<int8_t, int64_t>(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int16_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<int16_t, int64_t>(arg);
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int32_t>>>()) {
|
||||
return copyTensorLambdaArgumentTo64bitsvector<int32_t, int64_t>(arg);
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED std::vector<int64_t>
|
||||
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int8_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int16_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int32_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
if (auto arg =
|
||||
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int64_t>>>()) {
|
||||
return arg->getDimensions();
|
||||
}
|
||||
throw std::invalid_argument(
|
||||
"LambdaArgument isn't a tensor, should "
|
||||
"be a TensorLambdaArgument<IntLambdaArgument<(u)int{8,16,32,64}_t>>");
|
||||
std::vector<size_t> dims = lambda_arg.ptr->value.getDimensions();
|
||||
return {dims.begin(), dims.end()};
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
|
||||
auto ptr = lambda_arg.ptr;
|
||||
return ptr->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>() ||
|
||||
ptr->isa<mlir::concretelang::IntLambdaArgument<int64_t>>();
|
||||
return lambda_arg.ptr->value.isScalar();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsSigned(lambdaArgument &lambda_arg) {
|
||||
auto ptr = lambda_arg.ptr;
|
||||
return ptr->isa<mlir::concretelang::IntLambdaArgument<int8_t>>() ||
|
||||
ptr->isa<mlir::concretelang::IntLambdaArgument<int16_t>>() ||
|
||||
ptr->isa<mlir::concretelang::IntLambdaArgument<int32_t>>() ||
|
||||
ptr->isa<mlir::concretelang::IntLambdaArgument<int64_t>>() ||
|
||||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int8_t>>>() ||
|
||||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int16_t>>>() ||
|
||||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int32_t>>>() ||
|
||||
ptr->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int64_t>>>();
|
||||
;
|
||||
return lambda_arg.ptr->value.isSigned();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED uint64_t
|
||||
lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
|
||||
lambda_arg.ptr
|
||||
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
if (arg == nullptr) {
|
||||
if (lambda_arg.ptr->value.isScalar() &&
|
||||
lambda_arg.ptr->value.hasElementType<uint64_t>()) {
|
||||
return lambda_arg.ptr->value.getTensor<uint64_t>()->values[0];
|
||||
} else {
|
||||
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
|
||||
"be an IntLambdaArgument<uint64_t>");
|
||||
}
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED int64_t
|
||||
lambdaArgumentGetSignedScalar(lambdaArgument &lambda_arg) {
|
||||
mlir::concretelang::IntLambdaArgument<int64_t> *arg =
|
||||
lambda_arg.ptr
|
||||
->dyn_cast<mlir::concretelang::IntLambdaArgument<int64_t>>();
|
||||
if (arg == nullptr) {
|
||||
if (lambda_arg.ptr->value.isScalar() &&
|
||||
lambda_arg.ptr->value.hasElementType<int64_t>()) {
|
||||
return lambda_arg.ptr->value.getTensor<int64_t>()->values[0];
|
||||
} else {
|
||||
throw std::invalid_argument("LambdaArgument isn't a scalar, should "
|
||||
"be an IntLambdaArgument<int64_t>");
|
||||
}
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<uint8_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI8(
|
||||
std::vector<int8_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<int8_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int8_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
|
||||
std::vector<uint16_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<uint16_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI16(
|
||||
std::vector<int16_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<int16_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int16_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
|
||||
std::vector<uint32_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<uint32_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI32(
|
||||
std::vector<int32_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<int32_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int32_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
|
||||
std::vector<uint64_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<uint64_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorI64(
|
||||
std::vector<int64_t> data, std::vector<int64_t> dimensions) {
|
||||
std::vector<size_t> dims(dimensions.begin(), dimensions.end());
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<int64_t>(data, dims))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<int64_t>>>(data, dimensions)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
|
||||
auto val = Value{((Tensor<int64_t>)Tensor<uint64_t>(scalar))};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument scalar_arg{
|
||||
std::make_shared<mlir::concretelang::IntLambdaArgument<uint64_t>>(
|
||||
scalar)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return scalar_arg;
|
||||
}
|
||||
|
||||
MLIR_CAPI_EXPORTED lambdaArgument
|
||||
lambdaArgumentFromSignedScalar(int64_t scalar) {
|
||||
auto val = Value{Tensor<int64_t>(scalar)};
|
||||
mlir::concretelang::LambdaArgument out{val};
|
||||
lambdaArgument scalar_arg{
|
||||
std::make_shared<mlir::concretelang::IntLambdaArgument<int64_t>>(scalar)};
|
||||
std::make_shared<mlir::concretelang::LambdaArgument>(std::move(out))};
|
||||
return scalar_arg;
|
||||
}
|
||||
|
||||
@@ -3,16 +3,17 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang-c/Dialect/FHELinalg.h"
|
||||
#include "concretelang-c/Dialect/Tracing.h"
|
||||
#include "concretelang/Bindings/Python/CompilerAPIModule.h"
|
||||
#include "concretelang/Bindings/Python/DialectModules.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include "mlir-c/Bindings/Python/Interop.h"
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/RegisterEverything.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/IR/DialectRegistry.h"
|
||||
|
||||
#include "llvm-c/ErrorHandling.h"
|
||||
@@ -21,6 +22,19 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace mlir::concretelang::FHE;
|
||||
using namespace mlir::concretelang::FHELinalg;
|
||||
using namespace mlir::concretelang::Tracing;
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHE, fhe);
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg);
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg, FHELinalgDialect)
|
||||
|
||||
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(TRACING, tracing);
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Tracing, tracing, TracingDialect)
|
||||
|
||||
PYBIND11_MODULE(_concretelang, m) {
|
||||
m.doc() = "Concretelang Python Native Extension";
|
||||
llvm::sys::PrintStackTraceOnErrorSignal(/*argv=*/"");
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang/Bindings/Python/DialectModules.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
|
||||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir/Bindings/Python/PybindAdaptors.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
@@ -17,18 +18,47 @@
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
using namespace mlir::concretelang;
|
||||
using namespace mlir::concretelang::FHE;
|
||||
using namespace mlir::python::adaptors;
|
||||
|
||||
typedef struct {
|
||||
MlirType type;
|
||||
bool isError;
|
||||
} MlirTypeOrError;
|
||||
|
||||
template <typename T>
|
||||
MlirTypeOrError IntegerTypeGetChecked(MlirContext ctx, unsigned width) {
|
||||
MlirTypeOrError type = {{NULL}, false};
|
||||
auto catchError = [&]() -> mlir::InFlightDiagnostic {
|
||||
type.isError = true;
|
||||
mlir::DiagnosticEngine &engine = unwrap(ctx)->getDiagEngine();
|
||||
// The goal here is to make getChecked working, but we don't want the CAPI
|
||||
// to stop execution due to an error, and leave the error handling logic to
|
||||
// the user of the CAPI
|
||||
return engine.emit(mlir::UnknownLoc::get(unwrap(ctx)),
|
||||
mlir::DiagnosticSeverity::Warning);
|
||||
};
|
||||
T integerType = T::getChecked(catchError, unwrap(ctx), width);
|
||||
if (type.isError) {
|
||||
return type;
|
||||
}
|
||||
type.type = wrap(integerType);
|
||||
return type;
|
||||
}
|
||||
|
||||
/// Populate the fhe python module.
|
||||
void mlir::concretelang::python::populateDialectFHESubmodule(
|
||||
pybind11::module &m) {
|
||||
m.doc() = "FHE dialect Python native extension";
|
||||
|
||||
mlir_type_subclass(m, "EncryptedIntegerType", fheTypeIsAnEncryptedIntegerType)
|
||||
mlir_type_subclass(m, "EncryptedIntegerType",
|
||||
[](MlirType type) {
|
||||
return unwrap(type).isa<EncryptedUnsignedIntegerType>();
|
||||
})
|
||||
.def_classmethod("get", [](pybind11::object cls, MlirContext ctx,
|
||||
unsigned width) {
|
||||
MlirTypeOrError typeOrError =
|
||||
fheEncryptedIntegerTypeGetChecked(ctx, width);
|
||||
IntegerTypeGetChecked<EncryptedUnsignedIntegerType>(ctx, width);
|
||||
if (typeOrError.isError) {
|
||||
throw std::invalid_argument("can't create eint with the given width");
|
||||
}
|
||||
@@ -36,11 +66,13 @@ void mlir::concretelang::python::populateDialectFHESubmodule(
|
||||
});
|
||||
|
||||
mlir_type_subclass(m, "EncryptedSignedIntegerType",
|
||||
fheTypeIsAnEncryptedSignedIntegerType)
|
||||
[](MlirType type) {
|
||||
return unwrap(type).isa<EncryptedSignedIntegerType>();
|
||||
})
|
||||
.def_classmethod(
|
||||
"get", [](pybind11::object cls, MlirContext ctx, unsigned width) {
|
||||
MlirTypeOrError typeOrError =
|
||||
fheEncryptedSignedIntegerTypeGetChecked(ctx, width);
|
||||
IntegerTypeGetChecked<EncryptedSignedIntegerType>(ctx, width);
|
||||
if (typeOrError.isError) {
|
||||
throw std::invalid_argument(
|
||||
"can't create esint with the given width");
|
||||
|
||||
@@ -25,13 +25,10 @@ from .compilation_feedback import CompilationFeedback
|
||||
from .key_set import KeySet
|
||||
from .public_result import PublicResult
|
||||
from .public_arguments import PublicArguments
|
||||
from .jit_compilation_result import JITCompilationResult
|
||||
from .jit_lambda import JITLambda
|
||||
from .lambda_argument import LambdaArgument
|
||||
from .library_compilation_result import LibraryCompilationResult
|
||||
from .library_lambda import LibraryLambda
|
||||
from .client_support import ClientSupport
|
||||
from .jit_support import JITSupport
|
||||
from .library_support import LibrarySupport
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
from .value import Value
|
||||
|
||||
@@ -44,8 +44,8 @@ class ClientSupport(WrapperCpp):
|
||||
)
|
||||
super().__init__(client_support)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
@staticmethod
|
||||
def new() -> "ClientSupport":
|
||||
"""Build a ClientSupport.
|
||||
|
||||
@@ -176,7 +176,9 @@ class ClientSupport(WrapperCpp):
|
||||
f"public_result must be of type PublicResult, not {type(public_result)}"
|
||||
)
|
||||
lambda_arg = LambdaArgument.wrap(
|
||||
_ClientSupport.decrypt_result(keyset.cpp(), public_result.cpp())
|
||||
_ClientSupport.decrypt_result(
|
||||
client_parameters.cpp(), keyset.cpp(), public_result.cpp()
|
||||
)
|
||||
)
|
||||
|
||||
output_signs = client_parameters.output_signs()
|
||||
@@ -216,7 +218,6 @@ class ClientSupport(WrapperCpp):
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-return-statements,too-many-branches
|
||||
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError(
|
||||
"value of lambda argument must be either int, numpy.array or numpy.(u)int{8,16,32,64}"
|
||||
|
||||
@@ -39,8 +39,8 @@ class CompilationContext(WrapperCpp):
|
||||
)
|
||||
super().__init__(compilation_context)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
@staticmethod
|
||||
def new() -> "CompilationContext":
|
||||
"""Build a CompilationContext.
|
||||
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
"""JITCompilationResult."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITCompilationResult as _JITCompilationResult,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
|
||||
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class JITCompilationResult(WrapperCpp):
|
||||
"""JITCompilationResult holds the result of a JIT compilation.
|
||||
|
||||
It can be instrumented using the JITSupport to load client parameters and execute the compiled
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(self, jit_compilation_result: _JITCompilationResult):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_compilation_result (_JITCompilationResult): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_compilation_result is not of type _JITCompilationResult
|
||||
"""
|
||||
if not isinstance(jit_compilation_result, _JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"jit_compilation_result must be of type _JITCompilationResult, not "
|
||||
f"{type(jit_compilation_result)}"
|
||||
)
|
||||
super().__init__(jit_compilation_result)
|
||||
@@ -1,35 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
"""JITLambda."""
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITLambda as _JITLambda,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
|
||||
class JITLambda(WrapperCpp):
|
||||
"""JITLambda contains an in-memory executable code and can be ran using JITSupport.
|
||||
|
||||
It's an artifact of JIT compilation, which stays in memory and can be executed with the help of
|
||||
JITSupport.
|
||||
"""
|
||||
|
||||
def __init__(self, jit_lambda: _JITLambda):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_lambda (_JITLambda): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda is not of type JITLambda
|
||||
"""
|
||||
if not isinstance(jit_lambda, _JITLambda):
|
||||
raise TypeError(
|
||||
f"jit_lambda must be of type _JITLambda, not {type(jit_lambda)}"
|
||||
)
|
||||
super().__init__(jit_lambda)
|
||||
@@ -1,202 +0,0 @@
|
||||
# Part of the Concrete Compiler Project, under the BSD3 License with Zama Exceptions.
|
||||
# See https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt for license information.
|
||||
|
||||
"""JITSupport.
|
||||
|
||||
Just-in-time compilation provide a way to compile and execute an MLIR program while keeping the executable
|
||||
code in memory.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
JITSupport as _JITSupport,
|
||||
)
|
||||
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .utils import lookup_runtime_lib
|
||||
from .compilation_options import CompilationOptions
|
||||
from .jit_compilation_result import JITCompilationResult
|
||||
from .client_parameters import ClientParameters
|
||||
from .compilation_feedback import CompilationFeedback
|
||||
from .jit_lambda import JITLambda
|
||||
from .public_arguments import PublicArguments
|
||||
from .public_result import PublicResult
|
||||
from .wrapper import WrapperCpp
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
|
||||
|
||||
class JITSupport(WrapperCpp):
|
||||
"""Support class for JIT compilation and execution."""
|
||||
|
||||
def __init__(self, jit_support: _JITSupport):
|
||||
"""Wrap the native Cpp object.
|
||||
|
||||
Args:
|
||||
jit_support (_JITSupport): object to wrap
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_support is not of type _JITSupport
|
||||
"""
|
||||
if not isinstance(jit_support, _JITSupport):
|
||||
raise TypeError(
|
||||
f"jit_support must be of type _JITSupport, not {type(jit_support)}"
|
||||
)
|
||||
super().__init__(jit_support)
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(runtime_lib_path: Optional[str] = None) -> "JITSupport":
|
||||
"""Build a JITSupport.
|
||||
|
||||
Args:
|
||||
runtime_lib_path (Optional[str]): path to the runtime library. Defaults to None.
|
||||
|
||||
Raises:
|
||||
TypeError: if runtime_lib_path is not of type str or None
|
||||
|
||||
Returns:
|
||||
JITSupport
|
||||
"""
|
||||
if runtime_lib_path is None:
|
||||
runtime_lib_path = lookup_runtime_lib()
|
||||
else:
|
||||
if not isinstance(runtime_lib_path, str):
|
||||
raise TypeError(
|
||||
f"runtime_lib_path must be of type str, not {type(runtime_lib_path)}"
|
||||
)
|
||||
return JITSupport.wrap(_JITSupport(runtime_lib_path))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
def compile(
|
||||
self,
|
||||
mlir_program: str,
|
||||
options: CompilationOptions = CompilationOptions.new("main"),
|
||||
) -> JITCompilationResult:
|
||||
"""JIT compile an MLIR program using Concrete dialects.
|
||||
|
||||
Args:
|
||||
mlir_program (str): textual representation of the mlir program to compile
|
||||
options (CompilationOptions): compilation options
|
||||
|
||||
Raises:
|
||||
TypeError: if mlir_program is not of type str
|
||||
TypeError: if options is not of type CompilationOptions
|
||||
|
||||
Returns:
|
||||
JITCompilationResult: the result of the JIT compilation
|
||||
"""
|
||||
if not isinstance(mlir_program, str):
|
||||
raise TypeError(
|
||||
f"mlir_program must be of type str, not {type(mlir_program)}"
|
||||
)
|
||||
if not isinstance(options, CompilationOptions):
|
||||
raise TypeError(
|
||||
f"options must be of type CompilationOptions, not {type(options)}"
|
||||
)
|
||||
return JITCompilationResult.wrap(
|
||||
self.cpp().compile(mlir_program, options.cpp())
|
||||
)
|
||||
|
||||
def load_client_parameters(
|
||||
self, compilation_result: JITCompilationResult
|
||||
) -> ClientParameters:
|
||||
"""Load the client parameters from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
ClientParameters: appropriate client parameters for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return ClientParameters.wrap(
|
||||
self.cpp().load_client_parameters(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_compilation_feedback(
|
||||
self, compilation_result: JITCompilationResult
|
||||
) -> CompilationFeedback:
|
||||
"""Load the compilation feedback from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
CompilationFeedback: the compilation feedback for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return CompilationFeedback.wrap(
|
||||
self.cpp().load_compilation_feedback(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(self, compilation_result: JITCompilationResult) -> JITLambda:
|
||||
"""Load the JITLambda from the JIT compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation.
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
|
||||
Returns:
|
||||
JITLambda: loaded JITLambda to be executed
|
||||
"""
|
||||
if not isinstance(compilation_result, JITCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be a JITCompilationResult not {type(compilation_result)}"
|
||||
)
|
||||
return JITLambda.wrap(self.cpp().load_server_lambda(compilation_result.cpp()))
|
||||
|
||||
def server_call(
|
||||
self,
|
||||
jit_lambda: JITLambda,
|
||||
public_arguments: PublicArguments,
|
||||
evaluation_keys: EvaluationKeys,
|
||||
) -> PublicResult:
|
||||
"""Call the JITLambda with public_arguments.
|
||||
|
||||
Args:
|
||||
jit_lambda (JITLambda): A server lambda to call.
|
||||
public_arguments (PublicArguments): The arguments of the call.
|
||||
evaluation_keys (EvaluationKeys): Evalutation keys of the call.
|
||||
|
||||
Raises:
|
||||
TypeError: if jit_lambda is not of type JITLambda
|
||||
TypeError: if public_arguments is not of type PublicArguments
|
||||
TypeError: if evaluation_keys is not of type EvaluationKeys
|
||||
|
||||
Returns:
|
||||
PublicResult: the result of the call of the server lambda.
|
||||
"""
|
||||
if not isinstance(jit_lambda, JITLambda):
|
||||
raise TypeError(
|
||||
f"jit_lambda must be of type JITLambda, not {type(jit_lambda)}"
|
||||
)
|
||||
if not isinstance(public_arguments, PublicArguments):
|
||||
raise TypeError(
|
||||
f"public_arguments must be of type PublicArguments, not {type(public_arguments)}"
|
||||
)
|
||||
if not isinstance(evaluation_keys, EvaluationKeys):
|
||||
raise TypeError(
|
||||
f"evaluation_keys must be of type EvaluationKeys, not {type(evaluation_keys)}"
|
||||
)
|
||||
return PublicResult.wrap(
|
||||
self.cpp().server_call(
|
||||
jit_lambda.cpp(), public_arguments.cpp(), evaluation_keys.cpp()
|
||||
)
|
||||
)
|
||||
@@ -15,7 +15,6 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from .wrapper import WrapperCpp
|
||||
from .evaluation_keys import EvaluationKeys
|
||||
from .client_parameters import ClientParameters
|
||||
|
||||
|
||||
class KeySet(WrapperCpp):
|
||||
@@ -64,10 +63,6 @@ class KeySet(WrapperCpp):
|
||||
)
|
||||
return KeySet.wrap(_KeySet.deserialize(serialized_key_set))
|
||||
|
||||
def client_parameters(self) -> ClientParameters:
|
||||
"""Get client parameters of keyset."""
|
||||
return self.cpp().client_parameters()
|
||||
|
||||
def get_evaluation_keys(self) -> EvaluationKeys:
|
||||
"""
|
||||
Get evaluation keys for execution.
|
||||
|
||||
@@ -124,6 +124,8 @@ class LibrarySupport(WrapperCpp):
|
||||
generateCppHeader,
|
||||
)
|
||||
)
|
||||
if not os.path.isdir(output_path):
|
||||
os.makedirs(output_path)
|
||||
library_support.output_dir_path = output_path
|
||||
return library_support
|
||||
|
||||
@@ -216,27 +218,29 @@ class LibrarySupport(WrapperCpp):
|
||||
def load_compilation_feedback(
|
||||
self, compilation_result: LibraryCompilationResult
|
||||
) -> CompilationFeedback:
|
||||
"""Load the compilation feedback from the JIT compilation result.
|
||||
"""Load the compilation feedback from the compilation result.
|
||||
|
||||
Args:
|
||||
compilation_result (JITCompilationResult): result of the JIT compilation
|
||||
compilation_result (LibraryCompilationResult): result of the compilation
|
||||
|
||||
Raises:
|
||||
TypeError: if compilation_result is not of type JITCompilationResult
|
||||
TypeError: if compilation_result is not of type LibraryCompilationResult
|
||||
|
||||
Returns:
|
||||
CompilationFeedback: the compilation feedback for the compiled program
|
||||
"""
|
||||
if not isinstance(compilation_result, LibraryCompilationResult):
|
||||
raise TypeError(
|
||||
f"compilation_result must be of type JITCompilationResult, not {type(compilation_result)}"
|
||||
f"compilation_result must be of type LibraryCompilationResult, not {type(compilation_result)}"
|
||||
)
|
||||
return CompilationFeedback.wrap(
|
||||
self.cpp().load_compilation_feedback(compilation_result.cpp())
|
||||
)
|
||||
|
||||
def load_server_lambda(
|
||||
self, library_compilation_result: LibraryCompilationResult
|
||||
self,
|
||||
library_compilation_result: LibraryCompilationResult,
|
||||
simulation: bool,
|
||||
) -> LibraryLambda:
|
||||
"""Load the server lambda from the library compilation result.
|
||||
|
||||
@@ -255,7 +259,7 @@ class LibrarySupport(WrapperCpp):
|
||||
f"{type(library_compilation_result)}"
|
||||
)
|
||||
return LibraryLambda.wrap(
|
||||
self.cpp().load_server_lambda(library_compilation_result.cpp())
|
||||
self.cpp().load_server_lambda(library_compilation_result.cpp(), simulation)
|
||||
)
|
||||
|
||||
def server_call(
|
||||
@@ -340,10 +344,10 @@ class LibrarySupport(WrapperCpp):
|
||||
"""
|
||||
return self.cpp().get_shared_lib_path()
|
||||
|
||||
def get_client_parameters_path(self) -> str:
|
||||
"""Get the path where the client parameters file is expected to be.
|
||||
def get_program_info_path(self) -> str:
|
||||
"""Get the path where the program info file is expected to be.
|
||||
|
||||
Returns:
|
||||
str: path to the client parameters file
|
||||
str: path to the program info file
|
||||
"""
|
||||
return self.cpp().get_client_parameters_path()
|
||||
return self.cpp().get_program_info_path()
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
@@ -65,47 +65,16 @@ class SimulatedValueDecrypter(WrapperCpp):
|
||||
decrypted value
|
||||
"""
|
||||
|
||||
shape = tuple(self.cpp().get_shape(position))
|
||||
lambda_arg = self.cpp().decrypt(position, value.cpp())
|
||||
is_signed = lambda_arg.is_signed()
|
||||
if lambda_arg.is_scalar():
|
||||
return (
|
||||
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
|
||||
)
|
||||
|
||||
if len(shape) == 0:
|
||||
return self.decrypt_scalar(position, value)
|
||||
|
||||
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
|
||||
shape
|
||||
shape = lambda_arg.get_tensor_shape()
|
||||
return (
|
||||
np.array(lambda_arg.get_signed_tensor_data()).reshape(shape)
|
||||
if is_signed
|
||||
else np.array(lambda_arg.get_tensor_data()).reshape(shape)
|
||||
)
|
||||
|
||||
def decrypt_scalar(self, position: int, value: Value) -> int:
|
||||
"""
|
||||
Decrypt scalar.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
scalar value to decrypt
|
||||
|
||||
Returns:
|
||||
int:
|
||||
decrypted scalar
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_scalar(position, value.cpp())
|
||||
|
||||
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
|
||||
"""
|
||||
Decrypt tensor.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
tensor value to decrypt
|
||||
|
||||
Returns:
|
||||
List[int]:
|
||||
decrypted tensor
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_tensor(position, value.cpp())
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
|
||||
from typing import List, Union
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
@@ -66,47 +66,16 @@ class ValueDecrypter(WrapperCpp):
|
||||
decrypted value
|
||||
"""
|
||||
|
||||
shape = tuple(self.cpp().get_shape(position))
|
||||
lambda_arg = self.cpp().decrypt(position, value.cpp())
|
||||
is_signed = lambda_arg.is_signed()
|
||||
if lambda_arg.is_scalar():
|
||||
return (
|
||||
lambda_arg.get_signed_scalar() if is_signed else lambda_arg.get_scalar()
|
||||
)
|
||||
|
||||
if len(shape) == 0:
|
||||
return self.decrypt_scalar(position, value)
|
||||
|
||||
return np.array(self.decrypt_tensor(position, value), dtype=np.int64).reshape(
|
||||
shape
|
||||
shape = lambda_arg.get_tensor_shape()
|
||||
return (
|
||||
np.array(lambda_arg.get_signed_tensor_data()).reshape(shape)
|
||||
if is_signed
|
||||
else np.array(lambda_arg.get_tensor_data()).reshape(shape)
|
||||
)
|
||||
|
||||
def decrypt_scalar(self, position: int, value: Value) -> int:
|
||||
"""
|
||||
Decrypt scalar.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
scalar value to decrypt
|
||||
|
||||
Returns:
|
||||
int:
|
||||
decrypted scalar
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_scalar(position, value.cpp())
|
||||
|
||||
def decrypt_tensor(self, position: int, value: Value) -> List[int]:
|
||||
"""
|
||||
Decrypt tensor.
|
||||
|
||||
Args:
|
||||
position (int):
|
||||
position of the argument within the circuit
|
||||
|
||||
value (Value):
|
||||
tensor value to decrypt
|
||||
|
||||
Returns:
|
||||
List[int]:
|
||||
decrypted tensor
|
||||
"""
|
||||
|
||||
return self.cpp().decrypt_tensor(position, value.cpp())
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
[package]
|
||||
name = "concrete-compiler"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[build-dependencies]
|
||||
bindgen = "0.60.1"
|
||||
|
||||
[dev-dependencies]
|
||||
tempdir = "0.3.7"
|
||||
@@ -1,12 +0,0 @@
|
||||
# Rust Bindings
|
||||
|
||||
A Rust library providing an API to the Concrete Compiler.
|
||||
|
||||
### Build
|
||||
|
||||
Set `CONCRETE_COMPILER_INSTALL_DIR` to the right path before building with `cargo`
|
||||
|
||||
```bash
|
||||
$ export CONCRETE_COMPILER_INSTALL_DIR=/installation/path/concretecompiler/
|
||||
$ cargo build --release
|
||||
```
|
||||
@@ -1,35 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang-c/Dialect/FHE.h>
|
||||
#include <concretelang-c/Dialect/FHELinalg.h>
|
||||
#include <concretelang-c/Support/CompilerEngine.h>
|
||||
#include <mlir-c/AffineExpr.h>
|
||||
#include <mlir-c/AffineMap.h>
|
||||
#include <mlir-c/BuiltinAttributes.h>
|
||||
#include <mlir-c/BuiltinTypes.h>
|
||||
#include <mlir-c/Conversion.h>
|
||||
#include <mlir-c/Debug.h>
|
||||
#include <mlir-c/Diagnostics.h>
|
||||
#include <mlir-c/Dialect/Async.h>
|
||||
#include <mlir-c/Dialect/ControlFlow.h>
|
||||
#include <mlir-c/Dialect/Func.h>
|
||||
#include <mlir-c/Dialect/GPU.h>
|
||||
#include <mlir-c/Dialect/LLVM.h>
|
||||
#include <mlir-c/Dialect/Linalg.h>
|
||||
#include <mlir-c/Dialect/PDL.h>
|
||||
#include <mlir-c/Dialect/Quant.h>
|
||||
#include <mlir-c/Dialect/SCF.h>
|
||||
#include <mlir-c/Dialect/Shape.h>
|
||||
#include <mlir-c/Dialect/SparseTensor.h>
|
||||
#include <mlir-c/Dialect/Tensor.h>
|
||||
#include <mlir-c/ExecutionEngine.h>
|
||||
#include <mlir-c/IR.h>
|
||||
#include <mlir-c/IntegerSet.h>
|
||||
#include <mlir-c/Interfaces.h>
|
||||
#include <mlir-c/Pass.h>
|
||||
#include <mlir-c/RegisterEverything.h>
|
||||
#include <mlir-c/Support.h>
|
||||
#include <mlir-c/Transforms.h>
|
||||
@@ -1,378 +0,0 @@
|
||||
extern crate bindgen;
|
||||
|
||||
use std::env;
|
||||
use std::error::Error;
|
||||
use std::path::Path;
|
||||
use std::process::exit;
|
||||
|
||||
const MLIR_STATIC_LIBS: &[&str] = &[
|
||||
"MLIRArithAttrToLLVMConversion",
|
||||
"MLIRDestinationStyleOpInterface",
|
||||
"MLIRVectorTransformOps",
|
||||
"MLIRMemRefTransformOps",
|
||||
"MLIRGPUTransformOps",
|
||||
"MLIRAffineTransformOps",
|
||||
"MLIRBytecodeReader",
|
||||
"MLIRAsmParser",
|
||||
"MLIRIndexDialect",
|
||||
"MLIRMaskableOpInterface",
|
||||
"MLIRMaskingOpInterface",
|
||||
"MLIRInferIntRangeCommon",
|
||||
"MLIRShapedOpInterfaces",
|
||||
"MLIRTransformDialectUtils",
|
||||
"MLIRParallelCombiningOpInterface",
|
||||
"MLIRMemRefDialect",
|
||||
"MLIRVectorToSPIRV",
|
||||
"MLIRControlFlowInterfaces",
|
||||
"MLIRLinalgToStandard",
|
||||
"MLIRAnalysis",
|
||||
"MLIRSPIRVDeserialization",
|
||||
"MLIRTransformDialect",
|
||||
"MLIRSparseTensorPipelines",
|
||||
"MLIRVectorToGPU",
|
||||
"MLIRTranslateLib",
|
||||
"MLIRPass",
|
||||
"MLIRComplexToLibm",
|
||||
"MLIRInferTypeOpInterface",
|
||||
"MLIRMemRefToSPIRV",
|
||||
"MLIRAMDGPUToROCDL",
|
||||
"MLIRBufferizationTransformOps",
|
||||
"MLIRExecutionEngineUtils",
|
||||
"MLIRNVVMDialect",
|
||||
"MLIRSCFUtils",
|
||||
"MLIRLinalgTransforms",
|
||||
"MLIRParser",
|
||||
"MLIRFuncTransforms",
|
||||
"MLIRTosaTestPasses",
|
||||
"MLIRTosaToArith",
|
||||
"MLIRTensorDialect",
|
||||
"MLIRGPUTransforms",
|
||||
"MLIRLowerableDialectsToLLVM",
|
||||
"MLIRBufferizationToMemRef",
|
||||
"MLIRPresburger",
|
||||
"MLIRFuncDialect",
|
||||
"MLIRPDLToPDLInterp",
|
||||
"MLIRArithTransforms",
|
||||
"MLIRViewLikeInterface",
|
||||
"MLIRTargetCpp",
|
||||
"MLIROpenMPToLLVM",
|
||||
"MLIRSPIRVConversion",
|
||||
"MLIRNVGPUTransforms",
|
||||
"MLIRSparseTensorTransforms",
|
||||
"MLIRAffineAnalysis",
|
||||
"MLIRArmSVETransforms",
|
||||
"MLIRArmNeon2dToIntr",
|
||||
"MLIRDataLayoutInterfaces",
|
||||
"MLIRAffineTransforms",
|
||||
"MLIROpenACCToLLVMIRTranslation",
|
||||
"MLIRTensorUtils",
|
||||
"MLIRSPIRVSerialization",
|
||||
"MLIRShapeToStandard",
|
||||
"MLIRArithToSPIRV",
|
||||
"MLIRArithDialect",
|
||||
"MLIRFuncToSPIRV",
|
||||
"MLIRQuantUtils",
|
||||
"MLIRTensorTilingInterfaceImpl",
|
||||
"MLIRX86VectorToLLVMIRTranslation",
|
||||
"MLIRCopyOpInterface",
|
||||
"MLIRMathToLibm",
|
||||
"MLIRGPUToGPURuntimeTransforms",
|
||||
"MLIRLLVMDialect",
|
||||
"MLIRAffineDialect",
|
||||
"MLIRTransforms",
|
||||
"MLIRVectorTransforms",
|
||||
"MLIROpenMPDialect",
|
||||
"MLIRControlFlowDialect",
|
||||
"MLIRVectorUtils",
|
||||
"MLIRROCDLDialect",
|
||||
"MLIRPDLDialect",
|
||||
"MLIRAsyncDialect",
|
||||
"MLIRLinalgToLLVM",
|
||||
"MLIROpenACCDialect",
|
||||
"MLIRVectorDialect",
|
||||
"MLIROpenACCToSCF",
|
||||
"MLIRIR",
|
||||
"MLIRCAPIIR",
|
||||
"MLIRTargetLLVMIRImport",
|
||||
"MLIRTensorToLinalg",
|
||||
"MLIRCallInterfaces",
|
||||
"MLIRTensorInferTypeOpInterfaceImpl",
|
||||
"MLIRTransformDialectTransforms",
|
||||
"MLIRComplexDialect",
|
||||
"MLIRAffineUtils",
|
||||
"MLIRLoopLikeInterface",
|
||||
"MLIRDialect",
|
||||
"MLIRLinalgUtils",
|
||||
"MLIRSCFToSPIRV",
|
||||
"MLIRAffineToStandard",
|
||||
"MLIRX86VectorDialect",
|
||||
"MLIRGPUToVulkanTransforms",
|
||||
"MLIRRewrite",
|
||||
"MLIRAMXToLLVMIRTranslation",
|
||||
"MLIRInferIntRangeInterface",
|
||||
"MLIRCAPIRegisterEverything",
|
||||
"MLIRNVVMToLLVMIRTranslation",
|
||||
"MLIRAsyncTransforms",
|
||||
"MLIRPDLInterpDialect",
|
||||
"MLIRTransformUtils",
|
||||
"MLIRLinalgDialect",
|
||||
"MLIRMathDialect",
|
||||
"MLIRMemRefTransforms",
|
||||
"MLIRSPIRVModuleCombiner",
|
||||
"MLIRMathToLLVM",
|
||||
"MLIRControlFlowToLLVM",
|
||||
"MLIRArmSVEDialect",
|
||||
"MLIRSPIRVTranslateRegistration",
|
||||
"MLIRToLLVMIRTranslationRegistration",
|
||||
"MLIRSCFDialect",
|
||||
"MLIRTilingInterface",
|
||||
"MLIREmitCDialect",
|
||||
"MLIRTableGen",
|
||||
"MLIRTosaToSCF",
|
||||
"MLIROpenMPToLLVMIRTranslation",
|
||||
"MLIRSupport",
|
||||
"MLIROpenACCToLLVM",
|
||||
"MLIRAMDGPUDialect",
|
||||
"MLIRTosaToLinalg",
|
||||
"MLIRSparseTensorUtils",
|
||||
"MLIRFuncToLLVM",
|
||||
"MLIRTargetLLVMIRExport",
|
||||
"MLIRControlFlowToSPIRV",
|
||||
"MLIRReconcileUnrealizedCasts",
|
||||
"MLIRComplexToStandard",
|
||||
"MLIRMathTransforms",
|
||||
"MLIRSPIRVUtils",
|
||||
"MLIRCastInterfaces",
|
||||
"MLIRTosaToTensor",
|
||||
"MLIRGPUToSPIRV",
|
||||
"MLIRBufferizationDialect",
|
||||
"MLIRSCFToControlFlow",
|
||||
"MLIRArmSVEToLLVMIRTranslation",
|
||||
"MLIRExecutionEngine",
|
||||
"MLIRBufferizationTransforms",
|
||||
"MLIRSparseTensorDialect",
|
||||
"MLIRTensorToSPIRV",
|
||||
"MLIRVectorToSCF",
|
||||
"MLIRLLVMToLLVMIRTranslation",
|
||||
"MLIRNVGPUDialect",
|
||||
"MLIRAsyncToLLVM",
|
||||
"MLIRAMXDialect",
|
||||
"MLIRLinalgTransformOps",
|
||||
"MLIRMathToSPIRV",
|
||||
"MLIRSCFToOpenMP",
|
||||
"MLIRShapeDialect",
|
||||
"MLIRGPUToROCDLTransforms",
|
||||
"MLIRGPUToNVVMTransforms",
|
||||
"MLIRTensorTransforms",
|
||||
"MLIRSCFToGPU",
|
||||
"MLIRDialectUtils",
|
||||
"MLIRNVGPUToNVVM",
|
||||
"MLIRTosaDialect",
|
||||
"MLIRVectorToLLVM",
|
||||
"MLIRSPIRVDialect",
|
||||
"MLIRSideEffectInterfaces",
|
||||
"MLIRQuantDialect",
|
||||
"MLIRSCFTransforms",
|
||||
"MLIRMLProgramDialect",
|
||||
"MLIRDLTIDialect",
|
||||
"MLIRLinalgFrontend",
|
||||
"MLIRROCDLToLLVMIRTranslation",
|
||||
"MLIRArmNeonDialect",
|
||||
"MLIRSPIRVToLLVM",
|
||||
"MLIRLLVMIRTransforms",
|
||||
"MLIRTosaTransforms",
|
||||
"MLIRLLVMCommonConversion",
|
||||
"MLIRSCFTransformOps",
|
||||
"MLIRArmNeonToLLVMIRTranslation",
|
||||
"MLIRAMXTransforms",
|
||||
"MLIRSPIRVTransforms",
|
||||
"MLIRMemRefToLLVM",
|
||||
"MLIRSPIRVBinaryUtils",
|
||||
"MLIRArithUtils",
|
||||
"MLIRVectorInterfaces",
|
||||
"MLIRGPUOps",
|
||||
"MLIRComplexToLLVM",
|
||||
"MLIRShapeOpsTransforms",
|
||||
"MLIRX86VectorTransforms",
|
||||
"MLIRArithToLLVM",
|
||||
];
|
||||
|
||||
const LLVM_STATIC_LIBS: &[&str] = &[
|
||||
"LLVMAggressiveInstCombine",
|
||||
"LLVMAnalysis",
|
||||
"LLVMAsmParser",
|
||||
"LLVMAsmPrinter",
|
||||
"LLVMBinaryFormat",
|
||||
"LLVMBitReader",
|
||||
"LLVMBitstreamReader",
|
||||
"LLVMBitWriter",
|
||||
"LLVMCFGuard",
|
||||
"LLVMCodeGen",
|
||||
"LLVMCore",
|
||||
"LLVMCoroutines",
|
||||
"LLVMDebugInfoCodeView",
|
||||
"LLVMDebugInfoDWARF",
|
||||
"LLVMDebugInfoMSF",
|
||||
"LLVMDebugInfoPDB",
|
||||
"LLVMDemangle",
|
||||
"LLVMExecutionEngine",
|
||||
"LLVMFrontendOpenMP",
|
||||
"LLVMGlobalISel",
|
||||
"LLVMInstCombine",
|
||||
"LLVMInstrumentation",
|
||||
"LLVMipo",
|
||||
"LLVMIRReader",
|
||||
"LLVMJITLink",
|
||||
"LLVMLinker",
|
||||
"LLVMMC",
|
||||
"LLVMMCDisassembler",
|
||||
"LLVMMCParser",
|
||||
"LLVMObjCARCOpts",
|
||||
"LLVMObject",
|
||||
"LLVMOption",
|
||||
"LLVMOrcJIT",
|
||||
"LLVMOrcShared",
|
||||
"LLVMOrcTargetProcess",
|
||||
"LLVMPasses",
|
||||
"LLVMProfileData",
|
||||
"LLVMRemarks",
|
||||
"LLVMRuntimeDyld",
|
||||
"LLVMScalarOpts",
|
||||
"LLVMSelectionDAG",
|
||||
"LLVMSupport",
|
||||
"LLVMSymbolize",
|
||||
"LLVMTableGen",
|
||||
"LLVMTableGenGlobalISel",
|
||||
"LLVMTarget",
|
||||
"LLVMTargetParser",
|
||||
"LLVMTextAPI",
|
||||
"LLVMTransformUtils",
|
||||
"LLVMVectorize",
|
||||
];
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &[
|
||||
"LLVMAArch64Utils",
|
||||
"LLVMAArch64Info",
|
||||
"LLVMAArch64Desc",
|
||||
"LLVMAArch64CodeGen",
|
||||
];
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
const LLVM_TARGET_SPECIFIC_STATIC_LIBS: &[&str] = &["LLVMX86CodeGen", "LLVMX86Desc", "LLVMX86Info"];
|
||||
|
||||
const CONCRETE_COMPILER_STATIC_LIBS: &[&str] = &[
|
||||
"AnalysisUtils",
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
"ConcreteToCAPI",
|
||||
"ConcretelangConversion",
|
||||
"ConcretelangTransforms",
|
||||
"FHETensorOpsToLinalg",
|
||||
"ConcretelangServerLib",
|
||||
"CONCRETELANGCAPIFHE",
|
||||
"TFHEGlobalParametrization",
|
||||
"ConcretelangClientLib",
|
||||
"ConcretelangConcreteTransforms",
|
||||
"ConcretelangSDFGInterfaces",
|
||||
"ConcretelangSDFGTransforms",
|
||||
"CONCRETELANGCAPISupport",
|
||||
"FHELinalgDialect",
|
||||
"TracingDialect",
|
||||
"TracingDialectTransforms",
|
||||
"TracingToCAPI",
|
||||
"ConcretelangInterfaces",
|
||||
"TFHEDialect",
|
||||
"SimulateTFHE",
|
||||
"CONCRETELANGCAPIFHELINALG",
|
||||
"FHELinalgDialectTransforms",
|
||||
"FHEDialect",
|
||||
"FHEInterfaces",
|
||||
"FHEDialectTransforms",
|
||||
"TFHEToConcrete",
|
||||
"FHEToTFHECrt",
|
||||
"FHEToTFHEScalar",
|
||||
"TFHEDialectTransforms",
|
||||
"TFHEKeyNormalization",
|
||||
"concrete_optimizer",
|
||||
"LinalgExtras",
|
||||
"FHEDialectAnalysis",
|
||||
"ConcreteDialect",
|
||||
"RTDialectAnalysis",
|
||||
"SDFGDialect",
|
||||
"ExtractSDFGOps",
|
||||
"SDFGToStreamEmulator",
|
||||
"TFHEDialectAnalysis",
|
||||
"ConcreteDialectAnalysis",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
if let Err(error) = run() {
|
||||
eprintln!("{}", error);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn run() -> Result<(), Box<dyn Error>> {
|
||||
let mut include_paths = Vec::new();
|
||||
// if set, use installation path of concrete compiler to lookup libraries and include files
|
||||
match env::var("CONCRETE_COMPILER_INSTALL_DIR") {
|
||||
Ok(install_dir) => {
|
||||
println!("cargo:rustc-link-search={}/lib/", install_dir);
|
||||
include_paths.push(Path::new(&format!("{}/include/", install_dir)).to_path_buf());
|
||||
}
|
||||
Err(_e) => println!(
|
||||
"cargo:warning=You are not setting CONCRETE_COMPILER_INSTALL_DIR, \
|
||||
so your compiler/linker will have to lookup libs and include dirs on their own"
|
||||
),
|
||||
}
|
||||
// linking to static libs
|
||||
let all_static_libs = CONCRETE_COMPILER_STATIC_LIBS
|
||||
.into_iter()
|
||||
.chain(MLIR_STATIC_LIBS)
|
||||
.chain(LLVM_STATIC_LIBS)
|
||||
.chain(LLVM_TARGET_SPECIFIC_STATIC_LIBS);
|
||||
|
||||
for static_lib_name in all_static_libs {
|
||||
println!("cargo:rustc-link-lib=static={}", static_lib_name);
|
||||
}
|
||||
// concrete compiler runtime
|
||||
println!("cargo:rustc-link-lib=ConcretelangRuntime");
|
||||
// concrete optimizer
|
||||
// `-bundle` serve to not have multiple definition issues
|
||||
println!("cargo:rustc-link-lib=static:-bundle=concrete_optimizer_cpp");
|
||||
|
||||
// required by llvm
|
||||
println!("cargo:rustc-link-lib=ncurses");
|
||||
if let Some(name) = get_system_libcpp() {
|
||||
println!("cargo:rustc-link-lib={}", name);
|
||||
}
|
||||
// zlib
|
||||
println!("cargo:rustc-link-lib=z");
|
||||
|
||||
println!("cargo:rerun-if-changed=api.h");
|
||||
bindgen::builder()
|
||||
.header("api.h")
|
||||
.clang_args(
|
||||
include_paths
|
||||
.into_iter()
|
||||
.map(|path| format!("-I{}", path.to_str().unwrap())),
|
||||
)
|
||||
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
|
||||
.generate()
|
||||
.unwrap()
|
||||
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_system_libcpp() -> Option<&'static str> {
|
||||
if cfg!(target_env = "msvc") {
|
||||
None
|
||||
} else if cfg!(target_os = "macos") {
|
||||
Some("c++")
|
||||
} else {
|
||||
Some("stdc++")
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,768 +0,0 @@
|
||||
//! FHE dialect module
|
||||
|
||||
use crate::mlir::ffi::*;
|
||||
use crate::mlir::*;
|
||||
|
||||
pub fn create_fhe_add_eint_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.add_eint",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_add_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.add_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_eint_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_eint",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_int_eint_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(rhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_int_eint",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_negate_eint_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(eint)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.neg_eint",
|
||||
&[eint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_mul_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.mul_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_apply_lut_op(
|
||||
context: MlirContext,
|
||||
eint: MlirValue,
|
||||
lut: MlirValue,
|
||||
result_type: MlirType,
|
||||
) -> MlirOperation {
|
||||
create_op(
|
||||
context,
|
||||
"FHE.apply_lookup_table",
|
||||
&[eint, lut],
|
||||
[result_type].as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FHEError {
|
||||
InvalidFHEType,
|
||||
InvalidWidth,
|
||||
}
|
||||
|
||||
pub fn convert_eint_to_esint_type(
|
||||
context: MlirContext,
|
||||
eint_type: MlirType,
|
||||
) -> Result<MlirType, FHEError> {
|
||||
unsafe {
|
||||
let width = fheTypeIntegerWidthGet(eint_type);
|
||||
if width == 0 {
|
||||
return Err(FHEError::InvalidFHEType);
|
||||
}
|
||||
let type_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width);
|
||||
if type_or_error.isError {
|
||||
Err(FHEError::InvalidWidth)
|
||||
} else {
|
||||
Ok(type_or_error.type_)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_esint_to_eint_type(
|
||||
context: MlirContext,
|
||||
esint_type: MlirType,
|
||||
) -> Result<MlirType, FHEError> {
|
||||
unsafe {
|
||||
let width = fheTypeIntegerWidthGet(esint_type);
|
||||
if width == 0 {
|
||||
return Err(FHEError::InvalidFHEType);
|
||||
}
|
||||
let type_or_error = fheEncryptedIntegerTypeGetChecked(context, width);
|
||||
if type_or_error.isError {
|
||||
Err(FHEError::InvalidWidth)
|
||||
} else {
|
||||
Ok(type_or_error.type_)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [convert_eint_to_esint_type(context, mlirValueGetType(eint)).unwrap()];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.to_signed",
|
||||
&[eint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [convert_esint_to_eint_type(context, mlirValueGetType(esint)).unwrap()];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.to_unsigned",
|
||||
&[esint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_zero_eint_op(context: MlirContext, result_type: MlirType) -> MlirOperation {
|
||||
create_op(
|
||||
context,
|
||||
"FHE.zero",
|
||||
&[],
|
||||
[result_type].as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_fhe_zero_eint_tensor_op(
|
||||
context: MlirContext,
|
||||
result_type: MlirType,
|
||||
) -> MlirOperation {
|
||||
create_op(
|
||||
context,
|
||||
"FHE.zero_tensor",
|
||||
&[],
|
||||
[result_type].as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_invalid_fhe_eint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0);
|
||||
assert!(invalid_eint.isError);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_fhe_eint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint = eint_or_error.type_;
|
||||
assert!(fheTypeIsAnEncryptedIntegerType(eint));
|
||||
assert!(!fheTypeIsAnEncryptedSignedIntegerType(eint));
|
||||
assert_eq!(fheTypeIntegerWidthGet(eint), 5);
|
||||
let printed_eint = super::print_mlir_type_to_string(eint);
|
||||
let expected_eint = "!FHE.eint<5>";
|
||||
assert_eq!(printed_eint, expected_eint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_fhe_esint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
|
||||
assert!(!esint_or_error.isError);
|
||||
let esint = esint_or_error.type_;
|
||||
assert!(fheTypeIsAnEncryptedSignedIntegerType(esint));
|
||||
assert!(!fheTypeIsAnEncryptedIntegerType(esint));
|
||||
assert_eq!(fheTypeIntegerWidthGet(esint), 5);
|
||||
let printed_esint = super::print_mlir_type_to_string(esint);
|
||||
let expected_esint = "!FHE.esint<5>";
|
||||
assert_eq!(printed_esint, expected_esint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fhe_func() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 5-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint = eint_or_error.type_;
|
||||
|
||||
// set input/output types of the FHE circuit
|
||||
let func_input_types = [eint, eint];
|
||||
let func_output_types = [eint];
|
||||
|
||||
// create the func operation
|
||||
let func_op = create_func_with_block(
|
||||
context,
|
||||
"main",
|
||||
func_input_types.as_slice(),
|
||||
func_output_types.as_slice(),
|
||||
);
|
||||
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
|
||||
let func_args = [
|
||||
mlirBlockGetArgument(func_block, 0),
|
||||
mlirBlockGetArgument(func_block, 1),
|
||||
];
|
||||
|
||||
// create an FHE add_eint op and append it to the function block
|
||||
let add_eint_op = create_fhe_add_eint_op(context, func_args[0], func_args[1]);
|
||||
mlirBlockAppendOwnedOperation(func_block, add_eint_op);
|
||||
|
||||
// create ret operation and append it to the block
|
||||
let ret_op = create_ret_op(context, mlirOperationGetResult(add_eint_op, 0));
|
||||
mlirBlockAppendOwnedOperation(func_block, ret_op);
|
||||
|
||||
// create module to hold the previously created function
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
|
||||
|
||||
let printed_module =
|
||||
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
|
||||
let expected_module = "\
|
||||
module {
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %0 : !FHE.eint<5>
|
||||
}
|
||||
}
|
||||
";
|
||||
assert_eq!(printed_module, expected_module);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
let printed_op = print_mlir_operation_to_string(zero_op);
|
||||
let expected_op = "%0 = \"FHE.zero\"() : () -> !FHE.eint<6>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_tensor_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 4-bit eint tensor type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint = eint_or_error.type_;
|
||||
let shape: [i64; 3] = [60, 66, 73];
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let eint_tensor = mlirRankedTensorTypeGetChecked(
|
||||
location,
|
||||
3,
|
||||
shape.as_ptr(),
|
||||
eint,
|
||||
mlirAttributeGetNull(),
|
||||
);
|
||||
|
||||
let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor);
|
||||
let printed_op = print_mlir_operation_to_string(zero_op);
|
||||
let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// add eint with itself
|
||||
let add_eint_op = create_fhe_add_eint_op(context, eint_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, add_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(add_eint_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.add_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// add eint int
|
||||
let add_eint_int_op = create_fhe_add_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, add_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(add_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.add_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// sub eint with itself
|
||||
let sub_eint_op = create_fhe_sub_eint_op(context, eint_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// sub eint int
|
||||
let sub_eint_int_op = create_fhe_sub_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_int_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// sub int eint
|
||||
let sub_eint_int_op = create_fhe_sub_int_eint_op(context, int_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_int_eint\"(%c-55_i7, %0) : (i7, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negate_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// negate eint
|
||||
let neg_eint_op = create_fhe_negate_eint_op(context, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, neg_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(neg_eint_op);
|
||||
let expected_op = "%1 = \"FHE.neg_eint\"(%0) : (!FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// mul eint int
|
||||
let mul_eint_int_op = create_fhe_mul_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(mul_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.mul_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_signed_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// to signed
|
||||
let to_signed_op = create_fhe_to_signed_op(context, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, to_signed_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(to_signed_op);
|
||||
let expected_op = "%1 = \"FHE.to_signed\"(%0) : (!FHE.eint<6>) -> !FHE.esint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_unsigned_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit esint type
|
||||
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!esint_or_error.isError);
|
||||
let esint6_type = esint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, esint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let esint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// to unsigned
|
||||
let to_unsigned_op = create_fhe_to_unsigned_op(context, esint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, to_unsigned_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(to_unsigned_op);
|
||||
let expected_op = "%1 = \"FHE.to_unsigned\"(%0) : (!FHE.esint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an lut
|
||||
let table: [i64; 64] = [0; 64];
|
||||
let constant_lut_op = create_constant_flat_tensor_op(context, &table, 64);
|
||||
mlirBlockAppendOwnedOperation(main_block, constant_lut_op);
|
||||
let lut = mlirOperationGetResult(constant_lut_op, 0);
|
||||
// LUT op
|
||||
let apply_lut_op = create_fhe_apply_lut_op(context, eint_value, lut, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, apply_lut_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(apply_lut_op);
|
||||
let expected_op = "%1 = \"FHE.apply_lookup_table\"(%0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +0,0 @@
|
||||
pub mod compiler;
|
||||
pub mod fhe;
|
||||
pub mod fhelinalg;
|
||||
pub mod mlir;
|
||||
@@ -1,443 +0,0 @@
|
||||
//! MLIR module
|
||||
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
pub mod ffi {
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
}
|
||||
|
||||
use ffi::*;
|
||||
use std::ffi::CString;
|
||||
use std::ops::AddAssign;
|
||||
|
||||
pub(crate) unsafe extern "C" fn mlir_rust_string_receiver_callback(
|
||||
mlirStrRef: MlirStringRef,
|
||||
user_data: *mut ::std::os::raw::c_void,
|
||||
) {
|
||||
let rust_string = &mut *(user_data as *mut String);
|
||||
let slc = std::slice::from_raw_parts(mlirStrRef.data as *const u8, mlirStrRef.length as usize);
|
||||
rust_string.add_assign(&String::from_utf8_lossy(slc));
|
||||
}
|
||||
|
||||
pub fn print_mlir_operation_to_string(op: MlirOperation) -> String {
|
||||
let mut rust_string = String::default();
|
||||
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
|
||||
|
||||
unsafe {
|
||||
mlirOperationPrint(op, Some(mlir_rust_string_receiver_callback), receiver_ptr);
|
||||
}
|
||||
|
||||
rust_string
|
||||
}
|
||||
|
||||
pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String {
|
||||
let mut rust_string = String::default();
|
||||
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
|
||||
|
||||
unsafe {
|
||||
mlirTypePrint(
|
||||
mlir_type,
|
||||
Some(mlir_rust_string_receiver_callback),
|
||||
receiver_ptr,
|
||||
);
|
||||
}
|
||||
|
||||
rust_string
|
||||
}
|
||||
|
||||
/// Returns a function operation with a region that contains a block.
|
||||
///
|
||||
/// The function would be defined using the provided input and output types. The main block of the
|
||||
/// function can be later fetched, from which we can get function arguments, and it will be where
|
||||
/// we append operations.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use concrete_compiler::mlir::*;
|
||||
/// use concrete_compiler::mlir::ffi::*;
|
||||
/// unsafe{
|
||||
/// let context = mlirContextCreate();
|
||||
/// register_all_dialects(context);
|
||||
///
|
||||
/// // input/output types
|
||||
/// let func_input_types = [
|
||||
/// mlirIntegerTypeGet(context, 64),
|
||||
/// mlirIntegerTypeGet(context, 64),
|
||||
/// ];
|
||||
/// let func_output_types = [mlirIntegerTypeGet(context, 64)];
|
||||
///
|
||||
/// let func_op = create_func_with_block(
|
||||
/// context,
|
||||
/// "test",
|
||||
/// func_input_types.as_slice(),
|
||||
/// func_output_types.as_slice(),
|
||||
/// );
|
||||
///
|
||||
/// // we can fetch the main block of the function from the function region
|
||||
/// let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
|
||||
/// // we can get arguments to later be used as operands to other operations
|
||||
/// let func_args = [
|
||||
/// mlirBlockGetArgument(func_block, 0),
|
||||
/// mlirBlockGetArgument(func_block, 1),
|
||||
/// ];
|
||||
/// // to add an operation to the function, we will append it to the main block
|
||||
/// let addi_op = create_addi_op(context, func_args[0], func_args[1]);
|
||||
/// mlirBlockAppendOwnedOperation(func_block, addi_op);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
pub fn create_func_with_block(
|
||||
context: MlirContext,
|
||||
func_name: &str,
|
||||
func_input_types: &[MlirType],
|
||||
func_output_types: &[MlirType],
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
// create the main block of the function
|
||||
let locations = (0..func_input_types.len())
|
||||
.into_iter()
|
||||
.map(|_| mlirLocationUnknownGet(context))
|
||||
.collect::<Vec<_>>();
|
||||
let func_block = mlirBlockCreate(
|
||||
func_input_types.len().try_into().unwrap(),
|
||||
func_input_types.as_ptr(),
|
||||
locations.as_ptr(),
|
||||
);
|
||||
|
||||
// create region to hold the previously created block
|
||||
let func_region = mlirRegionCreate();
|
||||
mlirRegionAppendOwnedBlock(func_region, func_block);
|
||||
|
||||
// create function to hold the previously created region
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let func_str = CString::new("func.func").unwrap();
|
||||
let mut func_op_state =
|
||||
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
|
||||
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
|
||||
// set function attributes
|
||||
let func_type_str = CString::new("function_type").unwrap();
|
||||
let sym_name_str = CString::new("sym_name").unwrap();
|
||||
let func_name_str = CString::new(func_name).unwrap();
|
||||
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context,
|
||||
func_input_types.len().try_into().unwrap(),
|
||||
func_input_types.as_ptr(),
|
||||
func_output_types.len().try_into().unwrap(),
|
||||
func_output_types.as_ptr(),
|
||||
));
|
||||
let sym_name_attr = mlirStringAttrGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
|
||||
);
|
||||
mlirOperationStateAddAttributes(
|
||||
&mut func_op_state,
|
||||
2,
|
||||
[
|
||||
// func type
|
||||
mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
|
||||
),
|
||||
func_type_attr,
|
||||
),
|
||||
// func name
|
||||
mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(
|
||||
context,
|
||||
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
|
||||
),
|
||||
sym_name_attr,
|
||||
),
|
||||
]
|
||||
.as_ptr(),
|
||||
);
|
||||
let func_op = mlirOperationCreate(&mut func_op_state);
|
||||
|
||||
func_op
|
||||
}
|
||||
}
|
||||
|
||||
/// Generic function to create an MLIR operation.
|
||||
///
|
||||
/// Create an MLIR operation based on its mnemonic (e.g. addi), it's operands, result types, and
|
||||
/// attributes. Result types can be inferred automatically if the operation itself supports that.
|
||||
pub fn create_op(
|
||||
context: MlirContext,
|
||||
mnemonic: &str,
|
||||
operands: &[MlirValue],
|
||||
results: &[MlirType],
|
||||
attrs: &[MlirNamedAttribute],
|
||||
auto_result_type_inference: bool,
|
||||
) -> MlirOperation {
|
||||
let op_mnemonic = CString::new(mnemonic).unwrap();
|
||||
unsafe {
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let mut op_state = mlirOperationStateGet(
|
||||
mlirStringRefCreateFromCString(op_mnemonic.as_ptr()),
|
||||
location,
|
||||
);
|
||||
mlirOperationStateAddOperands(
|
||||
&mut op_state,
|
||||
operands.len().try_into().unwrap(),
|
||||
operands.as_ptr(),
|
||||
);
|
||||
mlirOperationStateAddAttributes(
|
||||
&mut op_state,
|
||||
attrs.len().try_into().unwrap(),
|
||||
attrs.as_ptr(),
|
||||
);
|
||||
if auto_result_type_inference {
|
||||
mlirOperationStateEnableResultTypeInference(&mut op_state);
|
||||
} else {
|
||||
mlirOperationStateAddResults(
|
||||
&mut op_state,
|
||||
results.len().try_into().unwrap(),
|
||||
results.as_ptr(),
|
||||
);
|
||||
}
|
||||
mlirOperationCreate(&mut op_state)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation {
|
||||
create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true)
|
||||
}
|
||||
|
||||
pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperation {
|
||||
create_op(context, "func.return", &[ret_value], &[], &[], false)
|
||||
}
|
||||
|
||||
pub fn create_constant_int_op(context: MlirContext, cst_value: i64, width: u32) -> MlirOperation {
|
||||
unsafe {
|
||||
let result_type = mlirIntegerTypeGet(context, width);
|
||||
let value_str = CString::new("value").unwrap();
|
||||
let value_attr = mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
|
||||
mlirIntegerAttrGet(result_type, cst_value),
|
||||
);
|
||||
create_op(
|
||||
context,
|
||||
"arith.constant",
|
||||
&[],
|
||||
&[result_type],
|
||||
&[value_attr],
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_constant_flat_tensor_op(
|
||||
context: MlirContext,
|
||||
cst_table: &[i64],
|
||||
bitwidth: u32,
|
||||
) -> MlirOperation {
|
||||
let shape = [cst_table.len().try_into().unwrap()];
|
||||
create_constant_tensor_op(context, &shape, cst_table, bitwidth)
|
||||
}
|
||||
|
||||
pub fn create_constant_tensor_op(
|
||||
context: MlirContext,
|
||||
shape: &[i64],
|
||||
cst_table: &[i64],
|
||||
bitwidth: u32,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let result_type = mlirRankedTensorTypeGet(
|
||||
shape.len().try_into().unwrap(),
|
||||
shape.as_ptr(),
|
||||
mlirIntegerTypeGet(context, bitwidth),
|
||||
mlirAttributeGetNull(),
|
||||
);
|
||||
let cst_table_attrs: Vec<MlirAttribute> = cst_table
|
||||
.into_iter()
|
||||
.map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, bitwidth), *value))
|
||||
.collect();
|
||||
let value_str = CString::new("value").unwrap();
|
||||
let value_attr = mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
|
||||
mlirDenseElementsAttrGet(
|
||||
result_type,
|
||||
cst_table.len().try_into().unwrap(),
|
||||
cst_table_attrs.as_ptr(),
|
||||
),
|
||||
);
|
||||
create_op(
|
||||
context,
|
||||
"arith.constant",
|
||||
&[],
|
||||
&[result_type],
|
||||
&[value_attr],
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn register_all_dialects(context: MlirContext) {
|
||||
let registry = mlirDialectRegistryCreate();
|
||||
mlirRegisterAllDialects(registry);
|
||||
mlirContextAppendDialectRegistry(context, registry);
|
||||
mlirContextLoadAllAvailableDialects(context);
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_function_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
let func_type = mlirFunctionTypeGet(context, 0, std::ptr::null(), 0, std::ptr::null());
|
||||
assert!(mlirTypeIsAFunction(func_type));
|
||||
mlirContextDestroy(context);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_parsing() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
let module_string = "
|
||||
module{
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%1 = arith.addi %arg0, %arg1 : i64
|
||||
return %1: i64
|
||||
}
|
||||
}";
|
||||
let module_cstring = CString::new(module_string).unwrap();
|
||||
let module_reference = mlirStringRefCreateFromCString(module_cstring.as_ptr());
|
||||
|
||||
let parsed_module = mlirModuleCreateParse(context, module_reference);
|
||||
let parsed_func = mlirBlockGetFirstOperation(mlirModuleGetBody(parsed_module));
|
||||
|
||||
let func_type_str = CString::new("function_type").unwrap();
|
||||
// just check that we do have a function here, which should be enough to know that parsing worked well
|
||||
assert!(mlirTypeIsAFunction(mlirTypeAttrGetValue(
|
||||
mlirOperationGetAttributeByName(
|
||||
parsed_func,
|
||||
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
|
||||
)
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_module_creation() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// input/output types
|
||||
let func_input_types = [
|
||||
mlirIntegerTypeGet(context, 64),
|
||||
mlirIntegerTypeGet(context, 64),
|
||||
];
|
||||
let func_output_types = [mlirIntegerTypeGet(context, 64)];
|
||||
|
||||
let func_op = create_func_with_block(
|
||||
context,
|
||||
"test",
|
||||
func_input_types.as_slice(),
|
||||
func_output_types.as_slice(),
|
||||
);
|
||||
|
||||
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
|
||||
let func_args = [
|
||||
mlirBlockGetArgument(func_block, 0),
|
||||
mlirBlockGetArgument(func_block, 1),
|
||||
];
|
||||
// create addi operation and append it to the block
|
||||
let addi_op = create_addi_op(context, func_args[0], func_args[1]);
|
||||
mlirBlockAppendOwnedOperation(func_block, addi_op);
|
||||
|
||||
// create ret operation and append it to the block
|
||||
let ret_op = create_ret_op(context, mlirOperationGetResult(addi_op, 0));
|
||||
mlirBlockAppendOwnedOperation(func_block, ret_op);
|
||||
|
||||
// create module to hold the previously created function
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
|
||||
|
||||
let printed_module =
|
||||
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
|
||||
let expected_module = "\
|
||||
module {
|
||||
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
|
||||
%0 = arith.addi %arg0, %arg1 : i64
|
||||
return %0 : i64
|
||||
}
|
||||
}
|
||||
";
|
||||
assert_eq!(
|
||||
printed_module, expected_module,
|
||||
"left: \n{}, right: \n{}",
|
||||
printed_module, expected_module
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_flat_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_flat_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_tensor_with_signle_elem() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<0> : tensor<2x2xi7>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_int() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
register_all_dialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_int_op = create_constant_int_op(context, 73, 10);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_int_op);
|
||||
let expected_op = "%c73_i10 = arith.constant 73 : i10\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,2 +0,0 @@
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Support)
|
||||
@@ -1,3 +0,0 @@
|
||||
add_subdirectory(FHE)
|
||||
add_subdirectory(FHELinalg)
|
||||
add_subdirectory(Tracing)
|
||||
@@ -1,11 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES FHE.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(
|
||||
CONCRETELANGCAPIFHE
|
||||
FHE.cpp
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRCAPIIR
|
||||
FHEDialect)
|
||||
@@ -1,76 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/FHE.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/StorageUniquerSupport.h"
|
||||
|
||||
using namespace mlir::concretelang::FHE;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHE, fhe, FHEDialect)
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Type API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T>
|
||||
MlirTypeOrError IntegerTypeGetChecked(MlirContext ctx, unsigned width) {
|
||||
MlirTypeOrError type = {{NULL}, false};
|
||||
auto catchError = [&]() -> mlir::InFlightDiagnostic {
|
||||
type.isError = true;
|
||||
mlir::DiagnosticEngine &engine = unwrap(ctx)->getDiagEngine();
|
||||
// The goal here is to make getChecked working, but we don't want the CAPI
|
||||
// to stop execution due to an error, and leave the error handling logic to
|
||||
// the user of the CAPI
|
||||
return engine.emit(mlir::UnknownLoc::get(unwrap(ctx)),
|
||||
mlir::DiagnosticSeverity::Warning);
|
||||
};
|
||||
T integerType = T::getChecked(catchError, unwrap(ctx), width);
|
||||
if (type.isError) {
|
||||
return type;
|
||||
}
|
||||
type.type = wrap(integerType);
|
||||
return type;
|
||||
}
|
||||
|
||||
bool fheTypeIsAnEncryptedIntegerType(MlirType type) {
|
||||
return unwrap(type).isa<EncryptedUnsignedIntegerType>();
|
||||
}
|
||||
|
||||
MlirTypeOrError fheEncryptedIntegerTypeGetChecked(MlirContext ctx,
|
||||
unsigned width) {
|
||||
return IntegerTypeGetChecked<EncryptedUnsignedIntegerType>(ctx, width);
|
||||
}
|
||||
|
||||
bool fheTypeIsAnEncryptedSignedIntegerType(MlirType type) {
|
||||
return unwrap(type).isa<EncryptedSignedIntegerType>();
|
||||
}
|
||||
|
||||
MlirTypeOrError fheEncryptedSignedIntegerTypeGetChecked(MlirContext ctx,
|
||||
unsigned width) {
|
||||
return IntegerTypeGetChecked<EncryptedSignedIntegerType>(ctx, width);
|
||||
}
|
||||
|
||||
unsigned fheTypeIntegerWidthGet(MlirType integerType) {
|
||||
mlir::Type type = unwrap(integerType);
|
||||
auto eint = type.dyn_cast_or_null<EncryptedUnsignedIntegerType>();
|
||||
if (eint) {
|
||||
return eint.getWidth();
|
||||
}
|
||||
auto esint = type.dyn_cast_or_null<EncryptedSignedIntegerType>();
|
||||
if (esint) {
|
||||
return esint.getWidth();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES FHELinalg.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(
|
||||
CONCRETELANGCAPIFHELINALG
|
||||
FHELinalg.cpp
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRCAPIIR
|
||||
FHELinalgDialect)
|
||||
@@ -1,20 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/FHELinalg.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
|
||||
using namespace mlir::concretelang::FHELinalg;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(FHELinalg, fhelinalg, FHELinalgDialect)
|
||||
@@ -1,11 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES Tracing.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(
|
||||
CONCRETELANGCAPITRACING
|
||||
Tracing.cpp
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRCAPIIR
|
||||
TracingDialect)
|
||||
@@ -1,19 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Dialect/Tracing.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Registration.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
|
||||
using namespace mlir::concretelang::Tracing;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect API.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Tracing, tracing, TracingDialect)
|
||||
@@ -1,4 +0,0 @@
|
||||
set(LLVM_OPTIONAL_SOURCES CompilerEngine.cpp)
|
||||
|
||||
add_mlir_public_c_api_library(CONCRETELANGCAPISupport CompilerEngine.cpp LINK_LIBS PUBLIC MLIRCAPIIR
|
||||
ConcretelangSupport)
|
||||
@@ -1,799 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang-c/Support/CompilerEngine.h"
|
||||
#include "concretelang/CAPI/Wrappers.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/LambdaArgument.h"
|
||||
#include "concretelang/Support/LambdaSupport.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <numeric>
|
||||
|
||||
#define C_STRUCT_CLEANER(c_struct) \
|
||||
auto *cpp = unwrap(c_struct); \
|
||||
if (cpp != NULL) \
|
||||
delete cpp; \
|
||||
const char *error = getErrorPtr(c_struct); \
|
||||
if (error != NULL) \
|
||||
delete[] error;
|
||||
|
||||
/// ********** BufferRef CAPI **************************************************
|
||||
|
||||
BufferRef bufferRefCreate(const char *buffer, size_t length) {
|
||||
return BufferRef{buffer, length, NULL};
|
||||
}
|
||||
|
||||
BufferRef bufferRefFromString(std::string str) {
|
||||
char *buffer = new char[str.size()];
|
||||
memcpy(buffer, str.c_str(), str.size());
|
||||
return bufferRefCreate(buffer, str.size());
|
||||
}
|
||||
|
||||
BufferRef bufferRefFromStringError(std::string error) {
|
||||
char *buffer = new char[error.size()];
|
||||
memcpy(buffer, error.c_str(), error.size());
|
||||
return BufferRef{NULL, 0, buffer};
|
||||
}
|
||||
|
||||
void bufferRefDestroy(BufferRef buffer) {
|
||||
if (buffer.data != NULL)
|
||||
delete[] buffer.data;
|
||||
if (buffer.error != NULL)
|
||||
delete[] buffer.error;
|
||||
}
|
||||
|
||||
/// ********** Utilities *******************************************************
|
||||
|
||||
void mlirStringRefDestroy(MlirStringRef str) { delete[] str.data; }
|
||||
|
||||
template <typename T> BufferRef serialize(T toSerialize) {
|
||||
std::ostringstream ostream(std::ios::binary);
|
||||
auto voidOrError = unwrap(toSerialize)->serialize(ostream);
|
||||
if (voidOrError.has_error()) {
|
||||
return bufferRefFromStringError(voidOrError.error().mesg);
|
||||
}
|
||||
return bufferRefFromString(ostream.str());
|
||||
}
|
||||
|
||||
/// ********** CompilationOptions CAPI *****************************************
|
||||
|
||||
CompilationOptions
|
||||
compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize,
|
||||
bool batchTFHEOps, bool dataflowParallelize,
|
||||
bool emitGPUOps, bool loopParallelize,
|
||||
bool optimizeTFHE, OptimizerConfig optimizerConfig,
|
||||
bool verifyDiagnostics) {
|
||||
std::string funcNameStr(funcName.data, funcName.length);
|
||||
auto options = new mlir::concretelang::CompilationOptions(funcNameStr);
|
||||
options->autoParallelize = autoParallelize;
|
||||
options->batchTFHEOps = batchTFHEOps;
|
||||
options->dataflowParallelize = dataflowParallelize;
|
||||
options->emitGPUOps = emitGPUOps;
|
||||
options->loopParallelize = loopParallelize;
|
||||
options->optimizeTFHE = optimizeTFHE;
|
||||
options->optimizerConfig = *unwrap(optimizerConfig);
|
||||
options->verifyDiagnostics = verifyDiagnostics;
|
||||
return wrap(options);
|
||||
}
|
||||
|
||||
CompilationOptions compilationOptionsCreateDefault() {
|
||||
return wrap(new mlir::concretelang::CompilationOptions("main"));
|
||||
}
|
||||
|
||||
void compilationOptionsDestroy(CompilationOptions options){
|
||||
C_STRUCT_CLEANER(options)}
|
||||
|
||||
/// ********** OptimizerConfig CAPI ********************************************
|
||||
|
||||
OptimizerConfig
|
||||
optimizerConfigCreate(bool display, double fallback_log_norm_woppbs,
|
||||
double global_p_error, double p_error,
|
||||
uint64_t security,
|
||||
mlir::concretelang::optimizer::Strategy strategy,
|
||||
bool use_gpu_constraints,
|
||||
uint32_t ciphertext_modulus_log,
|
||||
uint32_t fft_precision) {
|
||||
auto config = new mlir::concretelang::optimizer::Config();
|
||||
config->display = display;
|
||||
config->fallback_log_norm_woppbs = fallback_log_norm_woppbs;
|
||||
config->global_p_error = global_p_error;
|
||||
config->p_error = p_error;
|
||||
config->security = security;
|
||||
config->strategy = strategy;
|
||||
config->use_gpu_constraints = use_gpu_constraints;
|
||||
config->ciphertext_modulus_log = ciphertext_modulus_log;
|
||||
config->fft_precision = fft_precision;
|
||||
return wrap(config);
|
||||
}
|
||||
|
||||
OptimizerConfig optimizerConfigCreateDefault() {
|
||||
return wrap(new mlir::concretelang::optimizer::Config());
|
||||
}
|
||||
|
||||
void optimizerConfigDestroy(OptimizerConfig config){C_STRUCT_CLEANER(config)}
|
||||
|
||||
/// ********** CompilerEngine CAPI *********************************************
|
||||
|
||||
CompilerEngine compilerEngineCreate() {
|
||||
auto *engine = new mlir::concretelang::CompilerEngine(
|
||||
mlir::concretelang::CompilationContext::createShared());
|
||||
return wrap(engine);
|
||||
}
|
||||
|
||||
void compilerEngineDestroy(CompilerEngine engine){C_STRUCT_CLEANER(engine)}
|
||||
|
||||
/// Map C compilationTarget to Cpp
|
||||
llvm::Expected<mlir::concretelang::CompilerEngine::
|
||||
Target> targetConvertToCppFromC(CompilationTarget target) {
|
||||
switch (target) {
|
||||
case ROUND_TRIP:
|
||||
return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP;
|
||||
case FHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::FHE;
|
||||
case TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::TFHE;
|
||||
case PARAMETRIZED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::PARAMETRIZED_TFHE;
|
||||
case NORMALIZED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::NORMALIZED_TFHE;
|
||||
case BATCHED_TFHE:
|
||||
return mlir::concretelang::CompilerEngine::Target::BATCHED_TFHE;
|
||||
case CONCRETE:
|
||||
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
|
||||
case STD:
|
||||
return mlir::concretelang::CompilerEngine::Target::STD;
|
||||
case LLVM:
|
||||
return mlir::concretelang::CompilerEngine::Target::LLVM;
|
||||
case LLVM_IR:
|
||||
return mlir::concretelang::CompilerEngine::Target::LLVM_IR;
|
||||
case OPTIMIZED_LLVM_IR:
|
||||
return mlir::concretelang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
|
||||
case LIBRARY:
|
||||
return mlir::concretelang::CompilerEngine::Target::LIBRARY;
|
||||
}
|
||||
return mlir::concretelang::StreamStringError("invalid compilation target");
|
||||
}
|
||||
|
||||
CompilationResult compilerEngineCompile(CompilerEngine engine,
|
||||
MlirStringRef module,
|
||||
CompilationTarget target) {
|
||||
std::string module_str(module.data, module.length);
|
||||
auto targetCppOrError = targetConvertToCppFromC(target);
|
||||
if (!targetCppOrError) { // invalid target
|
||||
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
|
||||
llvm::toString(targetCppOrError.takeError()));
|
||||
}
|
||||
auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get());
|
||||
if (!retOrError) { // compilation error
|
||||
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
|
||||
llvm::toString(retOrError.takeError()));
|
||||
}
|
||||
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
|
||||
std::move(retOrError.get())));
|
||||
}
|
||||
|
||||
void compilerEngineCompileSetOptions(CompilerEngine engine,
|
||||
CompilationOptions options) {
|
||||
unwrap(engine)->setCompilationOptions(*unwrap(options));
|
||||
}
|
||||
|
||||
/// ********** CompilationResult CAPI ******************************************
|
||||
|
||||
MlirStringRef compilationResultGetModuleString(CompilationResult result) {
|
||||
// print the module into a string
|
||||
std::string moduleString;
|
||||
llvm::raw_string_ostream os(moduleString);
|
||||
unwrap(result)->mlirModuleRef->get().print(os);
|
||||
// allocate buffer and copy module string
|
||||
char *buffer = new char[moduleString.length() + 1];
|
||||
strcpy(buffer, moduleString.c_str());
|
||||
return mlirStringRefCreate(buffer, moduleString.length());
|
||||
}
|
||||
|
||||
void compilationResultDestroyModuleString(MlirStringRef str) {
|
||||
mlirStringRefDestroy(str);
|
||||
}
|
||||
|
||||
void compilationResultDestroy(CompilationResult result){
|
||||
C_STRUCT_CLEANER(result)}
|
||||
|
||||
/// ********** Library CAPI ****************************************************
|
||||
|
||||
Library libraryCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath, bool cleanUp) {
|
||||
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
|
||||
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
|
||||
runtimeLibraryPath.length);
|
||||
return wrap(new mlir::concretelang::CompilerEngine::Library(
|
||||
outputDirPathStr, runtimeLibraryPathStr, cleanUp));
|
||||
}
|
||||
|
||||
void libraryDestroy(Library lib) { C_STRUCT_CLEANER(lib) }
|
||||
|
||||
/// ********** LibraryCompilationResult CAPI ***********************************
|
||||
|
||||
void libraryCompilationResultDestroy(LibraryCompilationResult result){
|
||||
C_STRUCT_CLEANER(result)}
|
||||
|
||||
/// ********** LibrarySupport CAPI *********************************************
|
||||
|
||||
LibrarySupport
|
||||
librarySupportCreate(MlirStringRef outputDirPath,
|
||||
MlirStringRef runtimeLibraryPath,
|
||||
bool generateSharedLib, bool generateStaticLib,
|
||||
bool generateClientParameters,
|
||||
bool generateCompilationFeedback,
|
||||
bool generateCppHeader) {
|
||||
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
|
||||
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
|
||||
runtimeLibraryPath.length);
|
||||
return wrap(new mlir::concretelang::LibrarySupport(
|
||||
outputDirPathStr, runtimeLibraryPathStr, generateSharedLib,
|
||||
generateStaticLib, generateClientParameters, generateCompilationFeedback,
|
||||
generateCppHeader));
|
||||
}
|
||||
|
||||
LibraryCompilationResult librarySupportCompile(LibrarySupport support,
|
||||
MlirStringRef module,
|
||||
CompilationOptions options) {
|
||||
std::string moduleStr(module.data, module.length);
|
||||
auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options));
|
||||
if (!retOrError) {
|
||||
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL,
|
||||
llvm::toString(retOrError.takeError()));
|
||||
}
|
||||
return wrap(new mlir::concretelang::LibraryCompilationResult(
|
||||
*retOrError.get().release()));
|
||||
}
|
||||
|
||||
ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
|
||||
LibraryCompilationResult result) {
|
||||
auto serverLambdaOrError = unwrap(support)->loadServerLambda(*unwrap(result));
|
||||
if (!serverLambdaOrError) {
|
||||
return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL,
|
||||
llvm::toString(serverLambdaOrError.takeError()));
|
||||
}
|
||||
return wrap(new mlir::concretelang::serverlib::ServerLambda(
|
||||
serverLambdaOrError.get()));
|
||||
}
|
||||
|
||||
ClientParameters
|
||||
librarySupportLoadClientParameters(LibrarySupport support,
|
||||
LibraryCompilationResult result) {
|
||||
auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result));
|
||||
if (!paramsOrError) {
|
||||
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL,
|
||||
llvm::toString(paramsOrError.takeError()));
|
||||
}
|
||||
return wrap(
|
||||
new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get()));
|
||||
}
|
||||
|
||||
LibraryCompilationResult
|
||||
librarySupportLoadCompilationResult(LibrarySupport support) {
|
||||
auto retOrError = unwrap(support)->loadCompilationResult();
|
||||
if (!retOrError) {
|
||||
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL,
|
||||
llvm::toString(retOrError.takeError()));
|
||||
}
|
||||
return wrap(new mlir::concretelang::LibraryCompilationResult(
|
||||
*retOrError.get().release()));
|
||||
}
|
||||
|
||||
CompilationFeedback
|
||||
librarySupportLoadCompilationFeedback(LibrarySupport support,
|
||||
LibraryCompilationResult result) {
|
||||
auto feedbackOrError =
|
||||
unwrap(support)->loadCompilationFeedback(*unwrap(result));
|
||||
if (!feedbackOrError) {
|
||||
return wrap((mlir::concretelang::CompilationFeedback *)NULL,
|
||||
llvm::toString(feedbackOrError.takeError()));
|
||||
}
|
||||
return wrap(
|
||||
new mlir::concretelang::CompilationFeedback(feedbackOrError.get()));
|
||||
}
|
||||
|
||||
PublicResult librarySupportServerCall(LibrarySupport support,
|
||||
ServerLambda server_lambda,
|
||||
PublicArguments args,
|
||||
EvaluationKeys evalKeys) {
|
||||
auto resultOrError = unwrap(support)->serverCall(
|
||||
*unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys));
|
||||
if (!resultOrError) {
|
||||
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL,
|
||||
llvm::toString(resultOrError.takeError()));
|
||||
}
|
||||
return wrap(resultOrError.get().release());
|
||||
}
|
||||
|
||||
MlirStringRef librarySupportGetSharedLibPath(LibrarySupport support) {
|
||||
auto path = unwrap(support)->getSharedLibPath();
|
||||
// allocate buffer and copy module string
|
||||
char *buffer = new char[path.length() + 1];
|
||||
strcpy(buffer, path.c_str());
|
||||
return mlirStringRefCreate(buffer, path.length());
|
||||
}
|
||||
|
||||
MlirStringRef librarySupportGetClientParametersPath(LibrarySupport support) {
|
||||
auto path = unwrap(support)->getClientParametersPath();
|
||||
// allocate buffer and copy module string
|
||||
char *buffer = new char[path.length() + 1];
|
||||
strcpy(buffer, path.c_str());
|
||||
return mlirStringRefCreate(buffer, path.length());
|
||||
}
|
||||
|
||||
void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) }
|
||||
|
||||
/// ********** ServerLamda CAPI ************************************************
|
||||
|
||||
void serverLambdaDestroy(ServerLambda server){C_STRUCT_CLEANER(server)}
|
||||
|
||||
/// ********** ClientParameters CAPI *******************************************
|
||||
|
||||
BufferRef clientParametersSerialize(ClientParameters params) {
|
||||
llvm::json::Value value(*unwrap(params));
|
||||
std::string jsonParams;
|
||||
llvm::raw_string_ostream ostream(jsonParams);
|
||||
ostream << value;
|
||||
char *buffer = new char[jsonParams.size() + 1];
|
||||
strcpy(buffer, jsonParams.c_str());
|
||||
return bufferRefCreate(buffer, jsonParams.size());
|
||||
}
|
||||
|
||||
ClientParameters clientParametersUnserialize(BufferRef buffer) {
|
||||
std::string json(buffer.data, buffer.length);
|
||||
auto paramsOrError =
|
||||
llvm::json::parse<mlir::concretelang::ClientParameters>(json);
|
||||
if (!paramsOrError) {
|
||||
return wrap((mlir::concretelang::ClientParameters *)NULL,
|
||||
llvm::toString(paramsOrError.takeError()));
|
||||
}
|
||||
return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get()));
|
||||
}
|
||||
|
||||
ClientParameters clientParametersCopy(ClientParameters params) {
|
||||
return wrap(new mlir::concretelang::ClientParameters(*unwrap(params)));
|
||||
}
|
||||
|
||||
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
|
||||
|
||||
size_t clientParametersOutputsSize(ClientParameters params) {
|
||||
return unwrap(params)->outputs.size();
|
||||
}
|
||||
|
||||
size_t clientParametersInputsSize(ClientParameters params) {
|
||||
return unwrap(params)->inputs.size();
|
||||
}
|
||||
|
||||
CircuitGate clientParametersOutputCircuitGate(ClientParameters params,
|
||||
size_t index) {
|
||||
auto &cppGate = unwrap(params)->outputs[index];
|
||||
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
|
||||
return wrap(cppGateCopy);
|
||||
}
|
||||
|
||||
CircuitGate clientParametersInputCircuitGate(ClientParameters params,
|
||||
size_t index) {
|
||||
auto &cppGate = unwrap(params)->inputs[index];
|
||||
auto *cppGateCopy = new mlir::concretelang::clientlib::CircuitGate(cppGate);
|
||||
return wrap(cppGateCopy);
|
||||
}
|
||||
|
||||
EncryptionGate circuitGateEncryptionGate(CircuitGate circuit_gate) {
|
||||
auto &maybe_gate = unwrap(circuit_gate)->encryption;
|
||||
|
||||
if (maybe_gate) {
|
||||
auto *copy = new mlir::concretelang::clientlib::EncryptionGate(*maybe_gate);
|
||||
return wrap(copy);
|
||||
}
|
||||
return (static_cast<EncryptionGate (*)(
|
||||
mlir::concretelang::clientlib::EncryptionGate *)>(wrap))(nullptr);
|
||||
}
|
||||
|
||||
double encryptionGateVariance(EncryptionGate encryption_gate) {
|
||||
return unwrap(encryption_gate)->variance;
|
||||
}
|
||||
|
||||
Encoding encryptionGateEncoding(EncryptionGate encryption_gate) {
|
||||
auto &cppEncoding = unwrap(encryption_gate)->encoding;
|
||||
auto *copy = new mlir::concretelang::clientlib::Encoding(cppEncoding);
|
||||
return wrap(copy);
|
||||
}
|
||||
|
||||
uint64_t encodingPrecision(Encoding encoding) {
|
||||
return unwrap(encoding)->precision;
|
||||
}
|
||||
|
||||
void circuitGateDestroy(CircuitGate gate) { C_STRUCT_CLEANER(gate) }
|
||||
void encryptionGateDestroy(EncryptionGate gate) { C_STRUCT_CLEANER(gate) }
|
||||
void encodingDestroy(Encoding encoding){C_STRUCT_CLEANER(encoding)}
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = concretelang::clientlib::ConcreteCSPRNG(seed);
|
||||
|
||||
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
|
||||
*unwrap(params), std::move(csprng));
|
||||
if (keySet.has_error()) {
|
||||
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
|
||||
keySet.error().mesg);
|
||||
}
|
||||
return wrap(keySet.value().release());
|
||||
}
|
||||
|
||||
EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) {
|
||||
return wrap(new mlir::concretelang::clientlib::EvaluationKeys(
|
||||
unwrap(keySet)->evaluationKeys()));
|
||||
}
|
||||
|
||||
void keySetDestroy(KeySet keySet){C_STRUCT_CLEANER(keySet)}
|
||||
|
||||
/// ********** KeySetCache CAPI ************************************************
|
||||
|
||||
KeySetCache keySetCacheCreate(MlirStringRef cachePath) {
|
||||
std::string cachePathStr(cachePath.data, cachePath.length);
|
||||
return wrap(new mlir::concretelang::clientlib::KeySetCache(cachePathStr));
|
||||
}
|
||||
|
||||
KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
|
||||
ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb) {
|
||||
auto keySetOrError =
|
||||
unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb);
|
||||
if (keySetOrError.has_error()) {
|
||||
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
|
||||
keySetOrError.error().mesg);
|
||||
}
|
||||
return wrap(keySetOrError.value().release());
|
||||
}
|
||||
|
||||
void keySetCacheDestroy(KeySetCache keySetCache){C_STRUCT_CLEANER(keySetCache)}
|
||||
|
||||
/// ********** EvaluationKeys CAPI *********************************************
|
||||
|
||||
BufferRef evaluationKeysSerialize(EvaluationKeys keys) {
|
||||
std::ostringstream ostream(std::ios::binary);
|
||||
concretelang::clientlib::operator<<(ostream, *unwrap(keys));
|
||||
if (ostream.fail()) {
|
||||
return bufferRefFromStringError(
|
||||
"output stream failure during evaluation keys serialization");
|
||||
}
|
||||
return bufferRefFromString(ostream.str());
|
||||
}
|
||||
|
||||
EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) {
|
||||
std::stringstream istream(std::string(buffer.data, buffer.length));
|
||||
concretelang::clientlib::EvaluationKeys evaluationKeys =
|
||||
concretelang::clientlib::readEvaluationKeys(istream);
|
||||
if (istream.fail()) {
|
||||
return wrap((concretelang::clientlib::EvaluationKeys *)NULL,
|
||||
"input stream failure during evaluation keys unserialization");
|
||||
}
|
||||
return wrap(new concretelang::clientlib::EvaluationKeys(evaluationKeys));
|
||||
}
|
||||
|
||||
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
|
||||
C_STRUCT_CLEANER(evaluationKeys);
|
||||
}
|
||||
|
||||
/// ********** LambdaArgument CAPI *********************************************
|
||||
|
||||
LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
|
||||
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
|
||||
}
|
||||
|
||||
int64_t getSizeFromRankAndDims(size_t rank, const int64_t *dims) {
|
||||
if (rank == 0) // not a tensor
|
||||
return 1;
|
||||
auto size = dims[0];
|
||||
for (size_t i = 1; i < rank; i++)
|
||||
size *= dims[i];
|
||||
return size;
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU8(const uint8_t *data,
|
||||
const int64_t *dims, size_t rank) {
|
||||
|
||||
std::vector<uint8_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU16(const uint16_t *data,
|
||||
const int64_t *dims, size_t rank) {
|
||||
|
||||
std::vector<uint16_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU32(const uint32_t *data,
|
||||
const int64_t *dims, size_t rank) {
|
||||
|
||||
std::vector<uint32_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
LambdaArgument lambdaArgumentFromTensorU64(const uint64_t *data,
|
||||
const int64_t *dims, size_t rank) {
|
||||
|
||||
std::vector<uint64_t> data_vector(data,
|
||||
data + getSizeFromRankAndDims(rank, dims));
|
||||
std::vector<int64_t> dims_vector(dims, dims + rank);
|
||||
return wrap(new mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>(data_vector,
|
||||
dims_vector));
|
||||
}
|
||||
|
||||
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
}
|
||||
|
||||
uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
|
||||
unwrap(lambdaArg)
|
||||
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
assert(arg != nullptr && "lambda argument isn't a scalar");
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool copyTensorDataToBuffer(
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<T>> *tensor,
|
||||
uint64_t *buffer) {
|
||||
auto *data = tensor->getValue();
|
||||
auto sizeOrError = tensor->getNumElements();
|
||||
if (!sizeOrError) {
|
||||
llvm::errs() << llvm::toString(sizeOrError.takeError());
|
||||
return false;
|
||||
}
|
||||
auto size = sizeOrError.get();
|
||||
for (size_t i = 0; i < size; i++)
|
||||
buffer[i] = data[i];
|
||||
return true;
|
||||
}
|
||||
|
||||
bool lambdaArgumentGetTensorData(LambdaArgument lambdaArg, uint64_t *buffer) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return copyTensorDataToBuffer(tensor, buffer);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
return tensor->getDimensions().size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64_t lambdaArgumentGetTensorDataSize(LambdaArgument lambdaArg) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
std::vector<int64_t> dims;
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
return std::accumulate(std::begin(dims), std::end(dims), 1,
|
||||
std::multiplies<int64_t>());
|
||||
}
|
||||
|
||||
bool lambdaArgumentGetTensorDims(LambdaArgument lambdaArg, int64_t *buffer) {
|
||||
auto arg = unwrap(lambdaArg);
|
||||
std::vector<int64_t> dims;
|
||||
if (auto tensor = arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else if (auto tensor =
|
||||
arg->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
|
||||
dims = tensor->getDimensions();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
memcpy(buffer, dims.data(), sizeof(int64_t) * dims.size());
|
||||
return true;
|
||||
}
|
||||
|
||||
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
|
||||
size_t argNumber, ClientParameters params,
|
||||
KeySet keySet) {
|
||||
std::vector<const mlir::concretelang::LambdaArgument *> args;
|
||||
for (size_t i = 0; i < argNumber; i++)
|
||||
args.push_back(unwrap(lambdaArgs[i]));
|
||||
auto publicArgsOrError =
|
||||
mlir::concretelang::LambdaSupport<int, int>::exportArguments(
|
||||
*unwrap(params), *unwrap(keySet), args);
|
||||
if (!publicArgsOrError) {
|
||||
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL,
|
||||
llvm::toString(publicArgsOrError.takeError()));
|
||||
}
|
||||
return wrap(publicArgsOrError.get().release());
|
||||
}
|
||||
|
||||
void lambdaArgumentDestroy(LambdaArgument lambdaArg){
|
||||
C_STRUCT_CLEANER(lambdaArg)}
|
||||
|
||||
/// ********** PublicArguments CAPI ********************************************
|
||||
|
||||
BufferRef publicArgumentsSerialize(PublicArguments args) {
|
||||
return serialize(args);
|
||||
}
|
||||
|
||||
PublicArguments publicArgumentsUnserialize(BufferRef buffer,
|
||||
ClientParameters params) {
|
||||
std::stringstream istream(std::string(buffer.data, buffer.length));
|
||||
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
|
||||
*unwrap(params), istream);
|
||||
if (!argsOrError) {
|
||||
return wrap((concretelang::clientlib::PublicArguments *)NULL,
|
||||
argsOrError.error().mesg);
|
||||
}
|
||||
return wrap(argsOrError.value().release());
|
||||
}
|
||||
|
||||
void publicArgumentsDestroy(PublicArguments publicArgs){
|
||||
C_STRUCT_CLEANER(publicArgs)}
|
||||
|
||||
/// ********** PublicResult CAPI ***********************************************
|
||||
|
||||
LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
|
||||
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
|
||||
lambdaArgOrError = mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
*unwrap(keySet), *unwrap(publicResult));
|
||||
if (!lambdaArgOrError) {
|
||||
return wrap((mlir::concretelang::LambdaArgument *)NULL,
|
||||
llvm::toString(lambdaArgOrError.takeError()));
|
||||
}
|
||||
return wrap(lambdaArgOrError.get().release());
|
||||
}
|
||||
|
||||
BufferRef publicResultSerialize(PublicResult result) {
|
||||
return serialize(result);
|
||||
}
|
||||
|
||||
PublicResult publicResultUnserialize(BufferRef buffer,
|
||||
ClientParameters params) {
|
||||
std::stringstream istream(std::string(buffer.data, buffer.length));
|
||||
auto resultOrError = concretelang::clientlib::PublicResult::unserialize(
|
||||
*unwrap(params), istream);
|
||||
if (!resultOrError) {
|
||||
return wrap((concretelang::clientlib::PublicResult *)NULL,
|
||||
resultOrError.error().mesg);
|
||||
}
|
||||
return wrap(resultOrError.value().release());
|
||||
}
|
||||
|
||||
void publicResultDestroy(PublicResult publicResult) {
|
||||
C_STRUCT_CLEANER(publicResult)
|
||||
}
|
||||
|
||||
/// ********** CompilationFeedback CAPI ****************************************
|
||||
|
||||
double compilationFeedbackGetComplexity(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->complexity;
|
||||
}
|
||||
|
||||
double compilationFeedbackGetPError(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->pError;
|
||||
}
|
||||
|
||||
double compilationFeedbackGetGlobalPError(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->globalPError;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
compilationFeedbackGetTotalSecretKeysSize(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->totalSecretKeysSize;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
compilationFeedbackGetTotalBootstrapKeysSize(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->totalBootstrapKeysSize;
|
||||
}
|
||||
|
||||
uint64_t
|
||||
compilationFeedbackGetTotalKeyswitchKeysSize(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->totalKeyswitchKeysSize;
|
||||
}
|
||||
|
||||
uint64_t compilationFeedbackGetTotalInputsSize(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->totalInputsSize;
|
||||
}
|
||||
|
||||
uint64_t compilationFeedbackGetTotalOutputsSize(CompilationFeedback feedback) {
|
||||
return unwrap(feedback)->totalOutputsSize;
|
||||
}
|
||||
|
||||
void compilationFeedbackDestroy(CompilationFeedback feedback) {
|
||||
C_STRUCT_CLEANER(feedback)
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
add_subdirectory(Analysis)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Common)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Transforms)
|
||||
add_subdirectory(Support)
|
||||
@@ -8,4 +9,3 @@ add_subdirectory(ClientLib)
|
||||
add_subdirectory(Bindings)
|
||||
add_subdirectory(ServerLib)
|
||||
add_subdirectory(Interfaces)
|
||||
add_subdirectory(CAPI)
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
add_compile_options(-fexceptions)
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientLambda.cpp
|
||||
ClientParameters.cpp
|
||||
EvaluationKeys.cpp
|
||||
CRT.cpp
|
||||
EncryptedArguments.cpp
|
||||
KeySet.cpp
|
||||
KeySetCache.cpp
|
||||
PublicArguments.cpp
|
||||
Serializers.cpp
|
||||
ClientLib.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/ClientLib
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Common
|
||||
LINK_LIBS
|
||||
ConcretelangCommon
|
||||
PUBLIC
|
||||
concrete_cpu)
|
||||
concrete_cpu
|
||||
concrete-protocol)
|
||||
|
||||
target_include_directories(ConcretelangClientLib PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
|
||||
@@ -1,185 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <dlfcn.h>
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
outcome::checked<ClientLambda, StringError>
|
||||
ClientLambda::load(std::string functionName, std::string jsonPath) {
|
||||
OUTCOME_TRY(auto all_params, ClientParameters::load(jsonPath));
|
||||
auto param = llvm::find_if(all_params, [&](ClientParameters param) {
|
||||
return param.functionName == functionName;
|
||||
});
|
||||
|
||||
if (param == all_params.end()) {
|
||||
return StringError("ClientLambda: cannot find function ")
|
||||
<< functionName << " in client parameters" << jsonPath;
|
||||
}
|
||||
if (param->outputs.size() != 1) {
|
||||
return StringError("ClientLambda: output arity (")
|
||||
<< std::to_string(param->outputs.size()) << ") != 1 is not supprted";
|
||||
}
|
||||
|
||||
if (!param->outputs[0].encryption.has_value()) {
|
||||
return StringError("ClientLambda: clear output is not yet supported");
|
||||
}
|
||||
ClientLambda lambda;
|
||||
lambda.clientParameters = *param;
|
||||
return lambda;
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
|
||||
uint64_t seed_msb, uint64_t seed_lsb) {
|
||||
return KeySetCache::generate(optionalCache, clientParameters, seed_msb,
|
||||
seed_lsb);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
|
||||
return v[0];
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
|
||||
return result.asClearTextVector<decrypted_scalar_t>(keySet, 0);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
size_t actual) {
|
||||
return StringError("Expected result has rank ")
|
||||
<< expected << " and cannot be converted to rank " << actual;
|
||||
}
|
||||
|
||||
StringError errorIncoherentSizes(size_t flatSize, size_t structuredSize) {
|
||||
return StringError("Received ")
|
||||
<< flatSize << " values but is sizes indicates as global size of "
|
||||
<< structuredSize;
|
||||
}
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
DecryptedTensor flatToTensor(decrypted_tensor_1_t &values, size_t *sizes);
|
||||
|
||||
template <>
|
||||
decrypted_tensor_1_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
return values;
|
||||
}
|
||||
|
||||
template <>
|
||||
decrypted_tensor_2_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
decrypted_tensor_2_t result(sizes[0]);
|
||||
size_t position = 0;
|
||||
for (auto &dest0 : result) {
|
||||
dest0.resize(sizes[1]);
|
||||
for (auto &dest1 : dest0) {
|
||||
dest1 = values[position++];
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
decrypted_tensor_3_t result(sizes[0]);
|
||||
size_t position = 0;
|
||||
for (auto &dest0 : result) {
|
||||
dest0.resize(sizes[1]);
|
||||
for (auto &dest1 : dest0) {
|
||||
dest1.resize(sizes[2]);
|
||||
for (auto &dest2 : dest1) {
|
||||
dest2 = values[position++];
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
outcome::checked<DecryptedTensor, StringError>
|
||||
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
|
||||
ClientParameters ¶ms, size_t expectedRank,
|
||||
KeySet &keySet) {
|
||||
auto shape = params.outputs[0].shape;
|
||||
size_t rank = shape.dimensions.size();
|
||||
if (rank != expectedRank) {
|
||||
return StringError("Function returns a tensor of rank ")
|
||||
<< expectedRank << " which cannot be decrypted to rank " << rank;
|
||||
}
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
|
||||
llvm::SmallVector<size_t, 6> sizes;
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
sizes.push_back(shape.dimensions[dim]);
|
||||
}
|
||||
return flatToTensor<DecryptedTensor>(values, sizes.data());
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_1_t>(
|
||||
result, *this, this->clientParameters, 1, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_2_t>(
|
||||
result, *this, this->clientParameters, 2, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_3_t>(
|
||||
result, *this, this->clientParameters, 3, keySet);
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(COMPATIBLE_RESULT_TYPE)0;
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedScalar(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor1(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor2(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor3(keySet, result);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
132
compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Normal file
132
compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Normal file
@@ -0,0 +1,132 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "concrete-cpu.h"
|
||||
#include "concrete-protocol.capnp.h"
|
||||
#include "concretelang/ClientLib/ClientLib.h"
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Common/Keysets.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include "concretelang/Common/Transformers.h"
|
||||
#include "concretelang/Common/Values.h"
|
||||
|
||||
using concretelang::error::Result;
|
||||
using concretelang::keysets::ClientKeyset;
|
||||
using concretelang::transformers::InputTransformer;
|
||||
using concretelang::transformers::OutputTransformer;
|
||||
using concretelang::transformers::TransformerFactory;
|
||||
using concretelang::values::TransportValue;
|
||||
using concretelang::values::Value;
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
Result<ClientCircuit>
|
||||
ClientCircuit::create(const Message<concreteprotocol::CircuitInfo> &info,
|
||||
const ClientKeyset &keyset,
|
||||
std::shared_ptr<CSPRNG> csprng, bool useSimulation) {
|
||||
|
||||
auto inputTransformers = std::vector<InputTransformer>();
|
||||
|
||||
for (auto gateInfo : info.asReader().getInputs()) {
|
||||
InputTransformer transformer;
|
||||
if (gateInfo.getTypeInfo().hasIndex()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getIndexInputTransformer(gateInfo));
|
||||
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getPlaintextInputTransformer(gateInfo));
|
||||
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getLweCiphertextInputTransformer(
|
||||
keyset, gateInfo, csprng, useSimulation));
|
||||
} else {
|
||||
return StringError("Malformed input gate info.");
|
||||
}
|
||||
inputTransformers.push_back(transformer);
|
||||
}
|
||||
|
||||
auto outputTransformers = std::vector<OutputTransformer>();
|
||||
|
||||
for (auto gateInfo : info.asReader().getOutputs()) {
|
||||
OutputTransformer transformer;
|
||||
if (gateInfo.getTypeInfo().hasIndex()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getIndexOutputTransformer(gateInfo));
|
||||
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getPlaintextOutputTransformer(gateInfo));
|
||||
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
|
||||
OUTCOME_TRY(transformer,
|
||||
TransformerFactory::getLweCiphertextOutputTransformer(
|
||||
keyset, gateInfo, useSimulation));
|
||||
} else {
|
||||
return StringError("Malformed output gate info.");
|
||||
}
|
||||
outputTransformers.push_back(transformer);
|
||||
}
|
||||
|
||||
return ClientCircuit(info, inputTransformers, outputTransformers);
|
||||
}
|
||||
|
||||
Result<TransportValue> ClientCircuit::prepareInput(Value arg, size_t pos) {
|
||||
if (pos >= inputTransformers.size()) {
|
||||
return StringError("Tried to prepare a Value for incorrect position.");
|
||||
}
|
||||
return inputTransformers[pos](arg);
|
||||
}
|
||||
|
||||
Result<Value> ClientCircuit::processOutput(TransportValue result, size_t pos) {
|
||||
if (pos >= outputTransformers.size()) {
|
||||
return StringError(
|
||||
"Tried to process a TransportValue for incorrect position.");
|
||||
}
|
||||
return outputTransformers[pos](result);
|
||||
}
|
||||
|
||||
std::string ClientCircuit::getName() {
|
||||
return circuitInfo.asReader().getName();
|
||||
}
|
||||
|
||||
const Message<concreteprotocol::CircuitInfo> &ClientCircuit::getCircuitInfo() {
|
||||
return circuitInfo;
|
||||
}
|
||||
|
||||
Result<ClientProgram>
|
||||
ClientProgram::create(const Message<concreteprotocol::ProgramInfo> &info,
|
||||
const ClientKeyset &keyset,
|
||||
std::shared_ptr<CSPRNG> csprng, bool useSimulation) {
|
||||
ClientProgram output;
|
||||
for (auto circuitInfo : info.asReader().getCircuits()) {
|
||||
OUTCOME_TRY(
|
||||
ClientCircuit clientCircuit,
|
||||
ClientCircuit::create(circuitInfo, keyset, csprng, useSimulation));
|
||||
output.circuits.push_back(clientCircuit);
|
||||
}
|
||||
return output;
|
||||
}
|
||||
|
||||
Result<ClientCircuit> ClientProgram::getClientCircuit(std::string circuitName) {
|
||||
for (auto circuit : circuits) {
|
||||
if (circuit.getName() == circuitName) {
|
||||
return circuit;
|
||||
}
|
||||
}
|
||||
return StringError("Tried to get unknown client circuit: `" + circuitName +
|
||||
"`");
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,260 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#include "boost/outcome.h"
|
||||
#include "llvm/ADT/Hashing.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
// https://stackoverflow.com/a/38140932
|
||||
static inline void hash_(std::size_t &seed) {}
|
||||
template <typename T, typename... Rest>
|
||||
static inline void hash_(std::size_t &seed, const T &v, Rest... rest) {
|
||||
// See https://softwareengineering.stackexchange.com/a/402543
|
||||
const auto GOLDEN_RATIO = 0x9e3779b97f4a7c15; // pseudo random bits
|
||||
seed ^= llvm::hash_value(v) + GOLDEN_RATIO + (seed << 6) + (seed >> 2);
|
||||
hash_(seed, rest...);
|
||||
}
|
||||
|
||||
static long double_to_bits(double &v) { return *reinterpret_cast<long *>(&v); }
|
||||
|
||||
void LweSecretKeyParam::hash(size_t &seed) { hash_(seed, dimension); }
|
||||
|
||||
void BootstrapKeyParam::hash(size_t &seed) {
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
|
||||
glweDimension, double_to_bits(variance));
|
||||
}
|
||||
|
||||
void KeyswitchKeyParam::hash(size_t &seed) {
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
|
||||
double_to_bits(variance));
|
||||
}
|
||||
|
||||
void PackingKeyswitchKeyParam::hash(size_t &seed) {
|
||||
hash_(seed, inputSecretKeyID, outputSecretKeyID, level, baseLog,
|
||||
glweDimension, polynomialSize, inputLweDimension,
|
||||
double_to_bits(variance));
|
||||
}
|
||||
|
||||
std::size_t ClientParameters::hash() {
|
||||
std::size_t currentHash = 1;
|
||||
for (auto secretKeyParam : secretKeys) {
|
||||
secretKeyParam.hash(currentHash);
|
||||
}
|
||||
for (auto bootstrapKeyParam : bootstrapKeys) {
|
||||
bootstrapKeyParam.hash(currentHash);
|
||||
}
|
||||
for (auto keyswitchParam : keyswitchKeys) {
|
||||
keyswitchParam.hash(currentHash);
|
||||
}
|
||||
for (auto packingKeyswitchKeyParam : packingKeyswitchKeys) {
|
||||
packingKeyswitchKeyParam.hash(currentHash);
|
||||
}
|
||||
return currentHash;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const LweSecretKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"dimension", v.dimension},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, LweSecretKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("dimension", v.dimension);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const BootstrapKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"level", v.level},
|
||||
{"glweDimension", v.glweDimension},
|
||||
{"baseLog", v.baseLog},
|
||||
{"variance", v.variance},
|
||||
{"polynomialSize", v.polynomialSize},
|
||||
{"inputLweDimension", v.inputLweDimension},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value j, BootstrapKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("glweDimension", v.glweDimension) &&
|
||||
O.map("variance", v.variance) &&
|
||||
O.map("polynomialSize", v.polynomialSize) &&
|
||||
O.map("inputLweDimension", v.inputLweDimension);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const KeyswitchKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"level", v.level},
|
||||
{"baseLog", v.baseLog},
|
||||
{"variance", v.variance},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, KeyswitchKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("variance", v.variance);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const PackingKeyswitchKeyParam &v) {
|
||||
llvm::json::Object object{
|
||||
{"inputSecretKeyID", v.inputSecretKeyID},
|
||||
{"outputSecretKeyID", v.outputSecretKeyID},
|
||||
{"level", v.level},
|
||||
{"baseLog", v.baseLog},
|
||||
{"glweDimension", v.glweDimension},
|
||||
{"polynomialSize", v.polynomialSize},
|
||||
{"inputLweDimension", v.inputLweDimension},
|
||||
{"variance", v.variance},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, PackingKeyswitchKeyParam &v,
|
||||
llvm::json::Path p) {
|
||||
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("inputSecretKeyID", v.inputSecretKeyID) &&
|
||||
O.map("outputSecretKeyID", v.outputSecretKeyID) &&
|
||||
O.map("level", v.level) && O.map("baseLog", v.baseLog) &&
|
||||
O.map("glweDimension", v.glweDimension) &&
|
||||
O.map("polynomialSize", v.polynomialSize) &&
|
||||
O.map("inputLweDimension", v.inputLweDimension) &&
|
||||
O.map("variance", v.variance);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGateShape &v) {
|
||||
llvm::json::Object object{
|
||||
{"width", v.width},
|
||||
{"dimensions", v.dimensions},
|
||||
{"size", v.size},
|
||||
{"sign", v.sign},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGateShape &v,
|
||||
llvm::json::Path p) {
|
||||
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("width", v.width) && O.map("size", v.size) &&
|
||||
O.map("dimensions", v.dimensions) && O.map("sign", v.sign);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Encoding &v) {
|
||||
llvm::json::Object object{
|
||||
{"precision", v.precision},
|
||||
{"isSigned", v.isSigned},
|
||||
};
|
||||
if (!v.crt.empty()) {
|
||||
object.insert({"crt", v.crt});
|
||||
}
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, Encoding &v, llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
if (!(O && O.map("precision", v.precision) &&
|
||||
O.map("isSigned", v.isSigned))) {
|
||||
return false;
|
||||
}
|
||||
// TODO: check this is correct for an optional field
|
||||
O.map("crt", v.crt);
|
||||
return true;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const EncryptionGate &v) {
|
||||
llvm::json::Object object{
|
||||
{"secretKeyID", v.secretKeyID},
|
||||
{"variance", v.variance},
|
||||
{"encoding", v.encoding},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, EncryptionGate &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("secretKeyID", v.secretKeyID) &&
|
||||
O.map("variance", v.variance) && O.map("encoding", v.encoding);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const CircuitGate &v) {
|
||||
llvm::json::Object object{
|
||||
{"encryption", v.encryption},
|
||||
{"shape", v.shape},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, CircuitGate &v, llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("encryption", v.encryption) && O.map("shape", v.shape);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const ClientParameters &v) {
|
||||
llvm::json::Object object{
|
||||
{"secretKeys", v.secretKeys},
|
||||
{"bootstrapKeys", v.bootstrapKeys},
|
||||
{"keyswitchKeys", v.keyswitchKeys},
|
||||
{"packingKeyswitchKeys", v.packingKeyswitchKeys},
|
||||
{"inputs", v.inputs},
|
||||
{"outputs", v.outputs},
|
||||
{"functionName", v.functionName},
|
||||
};
|
||||
return object;
|
||||
}
|
||||
bool fromJSON(const llvm::json::Value j, ClientParameters &v,
|
||||
llvm::json::Path p) {
|
||||
llvm::json::ObjectMapper O(j, p);
|
||||
return O && O.map("secretKeys", v.secretKeys) &&
|
||||
O.map("bootstrapKeys", v.bootstrapKeys) &&
|
||||
O.map("keyswitchKeys", v.keyswitchKeys) &&
|
||||
O.map("packingKeyswitchKeys", v.packingKeyswitchKeys) &&
|
||||
O.map("inputs", v.inputs) && O.map("outputs", v.outputs) &&
|
||||
O.map("functionName", v.functionName);
|
||||
}
|
||||
|
||||
std::string ClientParameters::getClientParametersPath(std::string path) {
|
||||
return path + CLIENT_PARAMETERS_EXT;
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<ClientParameters>, StringError>
|
||||
ClientParameters::load(std::string jsonPath) {
|
||||
std::ifstream file(jsonPath);
|
||||
std::string content((std::istreambuf_iterator<char>(file)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
if (file.fail()) {
|
||||
return StringError("Cannot read file: ") << jsonPath;
|
||||
}
|
||||
auto expectedClientParams =
|
||||
llvm::json::parse<std::vector<ClientParameters>>(content);
|
||||
if (auto err = expectedClientParams.takeError()) {
|
||||
return StringError("Cannot open client parameters: ")
|
||||
<< llvm::toString(std::move(err)) << "\n"
|
||||
<< content << "\n";
|
||||
}
|
||||
return expectedClientParams.get();
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,63 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
|
||||
auto sharedValues = std::vector<SharedScalarOrTensorData>();
|
||||
sharedValues.reserve(this->values.size());
|
||||
|
||||
for (auto &&value : this->values) {
|
||||
sharedValues.push_back(SharedScalarOrTensorData(std::move(value)));
|
||||
}
|
||||
|
||||
return std::make_unique<PublicArguments>(clientParameters, sharedValues);
|
||||
}
|
||||
|
||||
/// Split the input integer into `size` chunks of `chunkWidth` bits each
|
||||
std::vector<uint64_t> chunkInput(uint64_t value, size_t size,
|
||||
unsigned int chunkWidth) {
|
||||
std::vector<uint64_t> chunks;
|
||||
chunks.reserve(size);
|
||||
uint64_t mask = (1 << chunkWidth) - 1;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
auto chunk = value & mask;
|
||||
chunks.push_back((uint64_t)chunk);
|
||||
value >>= chunkWidth;
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> checkSizes(size_t actualSize,
|
||||
size_t expectedSize) {
|
||||
if (actualSize == expectedSize) {
|
||||
return outcome::success();
|
||||
}
|
||||
return StringError("function expects ")
|
||||
<< expectedSize << " arguments but has been called with " << actualSize
|
||||
<< " arguments";
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::checkAllArgs(KeySet &keySet) {
|
||||
size_t arity = keySet.numInputs();
|
||||
return checkSizes(values.size(), arity);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArguments::checkAllArgs(ClientParameters ¶ms) {
|
||||
size_t arity = params.inputs.size();
|
||||
return checkSizes(values.size(), arity);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,157 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concrete-cpu.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
inline void getApproval() {
|
||||
std::cerr << "DANGER: You are generating an empty unsecure secret keys. "
|
||||
"Enter \"y\" to continue: ";
|
||||
char answer;
|
||||
std::cin >> answer;
|
||||
if (answer != 'y') {
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
ConcreteCSPRNG::ConcreteCSPRNG(__uint128_t seed)
|
||||
: CSPRNG(nullptr, &CONCRETE_CSPRNG_VTABLE) {
|
||||
ptr = (Csprng *)aligned_alloc(CONCRETE_CSPRNG_ALIGN, CONCRETE_CSPRNG_SIZE);
|
||||
struct Uint128 u128;
|
||||
if (seed == 0) {
|
||||
switch (concrete_cpu_crypto_secure_random_128(&u128)) {
|
||||
case 1:
|
||||
break;
|
||||
case -1:
|
||||
llvm::errs()
|
||||
<< "WARNING: The generated random seed is not crypto secure\n";
|
||||
break;
|
||||
default:
|
||||
assert(false && "Cannot instantiate a random seed");
|
||||
}
|
||||
|
||||
} else {
|
||||
for (int i = 0; i < 16; i++) {
|
||||
u128.little_endian_bytes[i] = seed >> (8 * i);
|
||||
}
|
||||
}
|
||||
concrete_cpu_construct_concrete_csprng(ptr, u128);
|
||||
}
|
||||
|
||||
ConcreteCSPRNG::ConcreteCSPRNG(ConcreteCSPRNG &&other)
|
||||
: CSPRNG(other.ptr, &CONCRETE_CSPRNG_VTABLE) {
|
||||
assert(ptr != nullptr);
|
||||
other.ptr = nullptr;
|
||||
}
|
||||
|
||||
ConcreteCSPRNG::~ConcreteCSPRNG() {
|
||||
if (ptr != nullptr) {
|
||||
concrete_cpu_destroy_concrete_csprng(ptr);
|
||||
free(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
LweSecretKey::LweSecretKey(LweSecretKeyParam ¶meters, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// Allocate the buffer
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(parameters.dimension);
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
// In insecure debug mode, the secret key is filled with zeros.
|
||||
getApproval();
|
||||
for (uint64_t &val : *_buffer) {
|
||||
val = 0;
|
||||
}
|
||||
#else
|
||||
// Initialize the lwe secret key buffer
|
||||
concrete_cpu_init_secret_key_u64(_buffer->data(), parameters.dimension,
|
||||
csprng.ptr, csprng.vtable);
|
||||
#endif
|
||||
}
|
||||
|
||||
void LweSecretKey::encrypt(uint64_t *ciphertext, uint64_t plaintext,
|
||||
double variance, CSPRNG &csprng) const {
|
||||
concrete_cpu_encrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
|
||||
plaintext, parameters().dimension,
|
||||
variance, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
void LweSecretKey::decrypt(const uint64_t *ciphertext,
|
||||
uint64_t &plaintext) const {
|
||||
concrete_cpu_decrypt_lwe_ciphertext_u64(_buffer->data(), ciphertext,
|
||||
parameters().dimension, &plaintext);
|
||||
}
|
||||
|
||||
LweKeyswitchKey::LweKeyswitchKey(KeyswitchKeyParam ¶meters,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_keyswitch_key_size_u64(
|
||||
_parameters.level, _parameters.baseLog, inputKey.dimension(),
|
||||
outputKey.dimension());
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size);
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_keyswitch_key_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
inputKey.dimension(), outputKey.dimension(), _parameters.level,
|
||||
_parameters.baseLog, _parameters.variance, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
LweBootstrapKey::LweBootstrapKey(BootstrapKeyParam ¶meters,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey, CSPRNG &csprng)
|
||||
: _parameters(parameters) {
|
||||
// TODO
|
||||
size_t polynomial_size = outputKey.dimension() / _parameters.glweDimension;
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_bootstrap_key_size_u64(
|
||||
_parameters.level, _parameters.glweDimension, polynomial_size,
|
||||
inputKey.dimension());
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size);
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_bootstrap_key_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
inputKey.dimension(), polynomial_size, _parameters.glweDimension,
|
||||
_parameters.level, _parameters.baseLog, _parameters.variance,
|
||||
Parallelism::Rayon, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
PackingKeyswitchKey::PackingKeyswitchKey(PackingKeyswitchKeyParam ¶ms,
|
||||
LweSecretKey &inputKey,
|
||||
LweSecretKey &outputKey,
|
||||
CSPRNG &csprng)
|
||||
: _parameters(params) {
|
||||
assert(_parameters.inputLweDimension == inputKey.dimension());
|
||||
assert(_parameters.glweDimension * _parameters.polynomialSize ==
|
||||
outputKey.dimension());
|
||||
|
||||
// Allocate the buffer
|
||||
auto size = concrete_cpu_lwe_packing_keyswitch_key_size(
|
||||
_parameters.glweDimension, _parameters.polynomialSize, _parameters.level,
|
||||
_parameters.inputLweDimension);
|
||||
_buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
_buffer->resize(size * (_parameters.glweDimension + 1));
|
||||
|
||||
// Initialize the keyswitch key buffer
|
||||
concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
|
||||
_buffer->data(), inputKey.buffer(), outputKey.buffer(),
|
||||
_parameters.inputLweDimension, _parameters.polynomialSize,
|
||||
_parameters.glweDimension, _parameters.level, _parameters.baseLog,
|
||||
_parameters.variance, Parallelism::Rayon, csprng.ptr, csprng.vtable);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,303 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/CRT.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySet::generate(ClientParameters clientParameters, CSPRNG &&csprng) {
|
||||
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
|
||||
OUTCOME_TRYV(keySet->generateKeysFromParams());
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError> KeySet::fromKeys(
|
||||
ClientParameters clientParameters, std::vector<LweSecretKey> secretKeys,
|
||||
std::vector<LweBootstrapKey> bootstrapKeys,
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys,
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys, CSPRNG &&csprng) {
|
||||
|
||||
auto keySet = std::make_unique<KeySet>(clientParameters, std::move(csprng));
|
||||
keySet->secretKeys = secretKeys;
|
||||
keySet->bootstrapKeys = bootstrapKeys;
|
||||
keySet->keyswitchKeys = keyswitchKeys;
|
||||
keySet->packingKeyswitchKeys = packingKeyswitchKeys;
|
||||
OUTCOME_TRYV(keySet->setupEncryptionMaterial());
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
EvaluationKeys KeySet::evaluationKeys() {
|
||||
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
|
||||
}
|
||||
|
||||
outcome::checked<KeySet::SecretKeyGateMapping, StringError>
|
||||
KeySet::mapCircuitGateLweSecretKey(std::vector<CircuitGate> gates) {
|
||||
SecretKeyGateMapping mapping;
|
||||
for (auto gate : gates) {
|
||||
if (gate.encryption.has_value()) {
|
||||
assert(gate.encryption->secretKeyID < this->secretKeys.size());
|
||||
auto skIt = this->secretKeys[gate.encryption->secretKeyID];
|
||||
|
||||
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {gate, skIt};
|
||||
mapping.push_back(input);
|
||||
} else {
|
||||
std::pair<CircuitGate, std::optional<LweSecretKey>> input = {
|
||||
gate, std::nullopt};
|
||||
mapping.push_back(input);
|
||||
}
|
||||
}
|
||||
return mapping;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> KeySet::setupEncryptionMaterial() {
|
||||
OUTCOME_TRY(this->inputs,
|
||||
mapCircuitGateLweSecretKey(_clientParameters.inputs));
|
||||
OUTCOME_TRY(this->outputs,
|
||||
mapCircuitGateLweSecretKey(_clientParameters.outputs));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> KeySet::generateKeysFromParams() {
|
||||
|
||||
// Generate LWE secret keys
|
||||
for (auto secretKeyParam : _clientParameters.secretKeys) {
|
||||
OUTCOME_TRYV(this->generateSecretKey(secretKeyParam));
|
||||
}
|
||||
// Generate bootstrap keys
|
||||
for (auto bootstrapKeyParam : _clientParameters.bootstrapKeys) {
|
||||
OUTCOME_TRYV(this->generateBootstrapKey(bootstrapKeyParam));
|
||||
}
|
||||
// Generate keyswitch key
|
||||
for (auto keyswitchParam : _clientParameters.keyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generateKeyswitchKey(keyswitchParam));
|
||||
}
|
||||
// Generate packing keyswitch key
|
||||
for (auto packingKeyswitchKeyParam : _clientParameters.packingKeyswitchKeys) {
|
||||
OUTCOME_TRYV(this->generatePackingKeyswitchKey(packingKeyswitchKeyParam));
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateSecretKey(LweSecretKeyParam param) {
|
||||
// Init the lwe secret key
|
||||
LweSecretKey sk(param, csprng);
|
||||
// Store the lwe secret key
|
||||
secretKeys.push_back(sk);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<LweSecretKey, StringError>
|
||||
KeySet::findLweSecretKey(LweSecretKeyID keyID) {
|
||||
assert(keyID < secretKeys.size());
|
||||
auto secretKey = secretKeys[keyID];
|
||||
|
||||
return secretKey;
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateBootstrapKey(BootstrapKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
|
||||
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
|
||||
// Initialize the bootstrap key
|
||||
LweBootstrapKey bootstrapKey(param, inputKey, outputKey, csprng);
|
||||
// Store the bootstrap key
|
||||
bootstrapKeys.push_back(std::move(bootstrapKey));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generateKeyswitchKey(KeyswitchKeyParam param) {
|
||||
// Finding input and output secretKeys
|
||||
OUTCOME_TRY(auto inputKey, findLweSecretKey(param.inputSecretKeyID));
|
||||
OUTCOME_TRY(auto outputKey, findLweSecretKey(param.outputSecretKeyID));
|
||||
// Initialize the bootstrap key
|
||||
LweKeyswitchKey keyswitchKey(param, inputKey, outputKey, csprng);
|
||||
// Store the keyswitch key
|
||||
keyswitchKeys.push_back(keyswitchKey);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::generatePackingKeyswitchKey(PackingKeyswitchKeyParam param) {
|
||||
// Finding input secretKeys
|
||||
assert(param.inputSecretKeyID < secretKeys.size());
|
||||
auto inputSk = secretKeys[param.inputSecretKeyID];
|
||||
|
||||
assert(param.outputSecretKeyID < secretKeys.size());
|
||||
auto outputSk = secretKeys[param.outputSecretKeyID];
|
||||
|
||||
PackingKeyswitchKey packingKeyswitchKey(param, inputSk, outputSk, csprng);
|
||||
// Store the keyswitch key
|
||||
packingKeyswitchKeys.push_back(packingKeyswitchKey);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext, uint64_t &size) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return StringError("allocate_lwe position of argument is too high");
|
||||
}
|
||||
auto inputSk = inputs[argPos];
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("allocate_lwe argument #")
|
||||
<< argPos << "is not encypeted";
|
||||
}
|
||||
auto numBlocks =
|
||||
encryption->encoding.crt.empty() ? 1 : encryption->encoding.crt.size();
|
||||
assert(inputSk.second.has_value());
|
||||
|
||||
size = inputSk.second->parameters().lweSize();
|
||||
*ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size * numBlocks);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
bool KeySet::isInputEncrypted(size_t argPos) {
|
||||
return argPos < inputs.size() &&
|
||||
std::get<0>(inputs[argPos]).encryption.has_value();
|
||||
}
|
||||
|
||||
bool KeySet::isOutputEncrypted(size_t argPos) {
|
||||
return argPos < outputs.size() &&
|
||||
std::get<0>(outputs[argPos]).encryption.has_value();
|
||||
}
|
||||
|
||||
/// Return the number of bits to represents the given value
|
||||
uint64_t bitWidthOfValue(uint64_t value) { return std::ceil(std::log2(value)); }
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
if (argPos >= inputs.size()) {
|
||||
return StringError("encrypt_lwe position of argument is too high");
|
||||
}
|
||||
const auto &inputSk = inputs[argPos];
|
||||
auto encryption = std::get<0>(inputSk).encryption;
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("encrypt_lwe the positional argument is not encrypted");
|
||||
}
|
||||
auto encoding = encryption->encoding;
|
||||
assert(inputSk.second.has_value());
|
||||
auto lweSecretKey = *inputSk.second;
|
||||
auto lweSecretKeyParam = lweSecretKey.parameters();
|
||||
// CRT encoding - N blocks with crt encoding
|
||||
auto crt = encryption->encoding.crt;
|
||||
if (!crt.empty()) {
|
||||
// Put each decomposition into a new ciphertext
|
||||
auto product = crt::productOfModuli(crt);
|
||||
for (auto modulus : crt) {
|
||||
auto plaintext = crt::encode(input, modulus, product);
|
||||
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
// Simple TFHE integers - 1 blocks with one padding bits
|
||||
// TODO we could check if the input value is in the right range
|
||||
uint64_t plaintext = input << (64 - (encryption->encoding.precision + 1));
|
||||
lweSecretKey.encrypt(ciphertext, plaintext, encryption->variance, csprng);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
if (argPos >= outputs.size()) {
|
||||
return StringError("decrypt_lwe: position of argument is too high");
|
||||
}
|
||||
auto outputSk = outputs[argPos];
|
||||
assert(outputSk.second.has_value());
|
||||
auto lweSecretKey = *outputSk.second;
|
||||
auto lweSecretKeyParam = lweSecretKey.parameters();
|
||||
auto encryption = std::get<0>(outputSk).encryption;
|
||||
if (!encryption.has_value()) {
|
||||
return StringError("decrypt_lwe: the positional argument is not encrypted");
|
||||
}
|
||||
|
||||
auto crt = encryption->encoding.crt;
|
||||
|
||||
if (!crt.empty()) {
|
||||
// CRT encoded TFHE integers
|
||||
|
||||
// Decrypt and decode remainders
|
||||
std::vector<int64_t> remainders;
|
||||
for (auto modulus : crt) {
|
||||
uint64_t decrypted = 0;
|
||||
lweSecretKey.decrypt(ciphertext, decrypted);
|
||||
|
||||
auto plaintext = crt::decode(decrypted, modulus);
|
||||
remainders.push_back(plaintext);
|
||||
ciphertext = ciphertext + lweSecretKeyParam.lweSize();
|
||||
}
|
||||
|
||||
// Compute the inverse crt
|
||||
output = crt::iCrt(crt, remainders);
|
||||
|
||||
// Further decode signed integers
|
||||
if (encryption->encoding.isSigned) {
|
||||
uint64_t maxPos = 1;
|
||||
for (auto prime : encryption->encoding.crt) {
|
||||
maxPos *= prime;
|
||||
}
|
||||
maxPos /= 2;
|
||||
if (output >= maxPos) {
|
||||
output -= maxPos * 2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Native encoded TFHE integers - 1 blocks with one padding bits
|
||||
uint64_t plaintext = 0;
|
||||
lweSecretKey.decrypt(ciphertext, plaintext);
|
||||
|
||||
// Decode unsigned integer
|
||||
uint64_t precision = encryption->encoding.precision;
|
||||
output = plaintext >> (64 - precision - 2);
|
||||
auto carry = output % 2;
|
||||
uint64_t mod = (((uint64_t)1) << (precision + 1));
|
||||
output = ((output >> 1) + carry) % mod;
|
||||
|
||||
// Further decode signed integers.
|
||||
if (encryption->encoding.isSigned) {
|
||||
uint64_t maxPos = (((uint64_t)1) << (precision - 1));
|
||||
if (output >= maxPos) { // The output is actually negative.
|
||||
// Set the preceding bits to zero
|
||||
output |= UINT64_MAX << precision;
|
||||
// This makes sure when the value is cast to int64, it has the correct
|
||||
// value
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
const std::vector<LweSecretKey> &KeySet::getSecretKeys() const {
|
||||
return secretKeys;
|
||||
}
|
||||
|
||||
const std::vector<LweBootstrapKey> &KeySet::getBootstrapKeys() const {
|
||||
return bootstrapKeys;
|
||||
}
|
||||
|
||||
const std::vector<LweKeyswitchKey> &KeySet::getKeyswitchKeys() const {
|
||||
return keyswitchKeys;
|
||||
}
|
||||
|
||||
const std::vector<PackingKeyswitchKey> &
|
||||
KeySet::getPackingKeyswitchKeys() const {
|
||||
return packingKeyswitchKeys;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,292 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <utime.h>
|
||||
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
inline void getApproval() {
|
||||
std::cerr << "DANGER: You are using an empty unsecure secret keys. Enter "
|
||||
"\"y\" to continue: ";
|
||||
char answer;
|
||||
std::cin >> answer;
|
||||
if (answer != 'y') {
|
||||
std::abort();
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
template <class Key>
|
||||
outcome::checked<Key, StringError> loadKey(llvm::SmallString<0> &path,
|
||||
Key(deser)(std::istream &istream)) {
|
||||
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
std::ifstream in((std::string)path, std::ofstream::binary);
|
||||
if (in.fail()) {
|
||||
return StringError("Cannot access " + (std::string)path);
|
||||
}
|
||||
auto key = deser(in);
|
||||
if (in.bad()) {
|
||||
return StringError("Cannot load key at path(") << (std::string)path << ")";
|
||||
}
|
||||
return key;
|
||||
}
|
||||
|
||||
template <class Key>
|
||||
outcome::checked<void, StringError> saveKey(llvm::SmallString<0> &path,
|
||||
Key key) {
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
std::ofstream out((std::string)path, std::ofstream::binary);
|
||||
if (out.fail()) {
|
||||
return StringError("Cannot access " + (std::string)path);
|
||||
}
|
||||
out << key;
|
||||
if (out.bad()) {
|
||||
return StringError("Cannot save key at path(") << (std::string)path << ")";
|
||||
}
|
||||
out.close();
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::loadKeys(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb, std::string folderPath) {
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
|
||||
// Mark the folder as recently use.
|
||||
// e.g. so the CI can do some cleanup of unused keys.
|
||||
utime(folderPath.c_str(), nullptr);
|
||||
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
|
||||
// Load secret keys
|
||||
for (auto p : llvm::enumerate(params.secretKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = secretKeyParam.second;
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweSecretKey));
|
||||
secretKeys.push_back(key);
|
||||
}
|
||||
// Load bootstrap keys
|
||||
for (auto p : llvm::enumerate(params.bootstrapKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweBootstrapKey));
|
||||
bootstrapKeys.push_back(key);
|
||||
}
|
||||
// Load keyswitch keys
|
||||
for (auto p : llvm::enumerate(params.keyswitchKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readLweKeyswitchKey));
|
||||
keyswitchKeys.push_back(key);
|
||||
}
|
||||
|
||||
for (auto p : llvm::enumerate(params.packingKeyswitchKeys)) {
|
||||
// TODO - Check parameters?
|
||||
// auto param = p.value();
|
||||
llvm::SmallString<0> path(folderPath);
|
||||
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRY(auto key, loadKey(path, readPackingKeyswitchKey));
|
||||
packingKeyswitchKeys.push_back(key);
|
||||
}
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
|
||||
OUTCOME_TRY(auto keySet,
|
||||
KeySet::fromKeys(params, secretKeys, bootstrapKeys, keyswitchKeys,
|
||||
packingKeyswitchKeys, std::move(csprng)));
|
||||
|
||||
return std::move(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> saveKeys(KeySet &key_set,
|
||||
llvm::SmallString<0> &folderPath) {
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
|
||||
llvm::SmallString<0> folderIncompletePath = folderPath;
|
||||
|
||||
folderIncompletePath.append(".incomplete");
|
||||
|
||||
auto err = llvm::sys::fs::create_directories(folderIncompletePath);
|
||||
if (err) {
|
||||
return StringError("Cannot create directory \"")
|
||||
<< std::string(folderIncompletePath) << "\": " << err.message();
|
||||
}
|
||||
// Save LWE secret keys
|
||||
for (auto p : llvm::enumerate(key_set.getSecretKeys())) {
|
||||
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "secretKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save bootstrap keys
|
||||
for (auto p : llvm::enumerate(key_set.getBootstrapKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "pbsKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save keyswitch keys
|
||||
for (auto p : llvm::enumerate(key_set.getKeyswitchKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "ksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
// Save packing keyswitch keys
|
||||
for (auto p : llvm::enumerate(key_set.getPackingKeyswitchKeys())) {
|
||||
llvm::SmallString<0> path = folderIncompletePath;
|
||||
llvm::sys::path::append(path, "pksKey_" + std::to_string(p.index()));
|
||||
OUTCOME_TRYV(saveKey(path, p.value()));
|
||||
}
|
||||
|
||||
err = llvm::sys::fs::rename(folderIncompletePath, folderPath);
|
||||
if (err) {
|
||||
llvm::sys::fs::remove_directories(folderIncompletePath);
|
||||
}
|
||||
if (!llvm::sys::fs::exists(folderPath)) {
|
||||
return StringError("Cannot save directory \"")
|
||||
<< std::string(folderPath) << "\"";
|
||||
}
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::loadOrGenerateSave(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
|
||||
llvm::SmallString<0> folderPath =
|
||||
llvm::SmallString<0>(this->backingDirectoryPath);
|
||||
|
||||
llvm::sys::path::append(folderPath, std::to_string(params.hash()));
|
||||
|
||||
llvm::sys::path::append(folderPath, std::to_string(seed_msb) + "_" +
|
||||
std::to_string(seed_lsb));
|
||||
|
||||
// Creating a lock for concurrent generation
|
||||
llvm::SmallString<0> lockPath(folderPath);
|
||||
lockPath.append("lock");
|
||||
int FD_lock;
|
||||
llvm::sys::fs::create_directories(llvm::sys::path::parent_path(lockPath));
|
||||
// Open or create the lock file
|
||||
auto err = llvm::sys::fs::openFile(
|
||||
lockPath, FD_lock, llvm::sys::fs::CreationDisposition::CD_OpenAlways,
|
||||
llvm::sys::fs::FileAccess::FA_Write, llvm::sys::fs::OpenFlags::OF_None);
|
||||
|
||||
if (err) {
|
||||
// parent does not exists OR right issue (creation or write)
|
||||
return StringError("Cannot access \"")
|
||||
<< std::string(lockPath) << "\": " << err.message();
|
||||
}
|
||||
|
||||
// The lock is released when the function returns.
|
||||
// => any intermediate state in the function is not visible to others.
|
||||
auto unlockAtReturn = llvm::make_scope_exit([&]() {
|
||||
llvm::sys::fs::closeFile(FD_lock);
|
||||
llvm::sys::fs::unlockFile(FD_lock);
|
||||
llvm::sys::fs::remove(lockPath);
|
||||
});
|
||||
llvm::sys::fs::lockFile(FD_lock);
|
||||
|
||||
if (llvm::sys::fs::exists(folderPath)) {
|
||||
// Once it has been generated by another process (or was alread here)
|
||||
auto keys = loadKeys(params, seed_msb, seed_lsb, std::string(folderPath));
|
||||
if (keys.has_value()) {
|
||||
return keys;
|
||||
} else {
|
||||
std::cerr << std::string(keys.error().mesg) << "\n";
|
||||
std::cerr << "Invalid KeySetCache entry " << std::string(folderPath)
|
||||
<< "\n";
|
||||
llvm::sys::fs::remove_directories(folderPath);
|
||||
// Then we can continue as it didn't exist
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "KeySetCache: miss, regenerating " << std::string(folderPath)
|
||||
<< "\n";
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
|
||||
OUTCOME_TRY(auto key_set, KeySet::generate(params, std::move(csprng)));
|
||||
|
||||
OUTCOME_TRYV(saveKeys(*key_set, folderPath));
|
||||
|
||||
return std::move(key_set);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
|
||||
ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
|
||||
__uint128_t seed = seed_msb;
|
||||
seed <<= 64;
|
||||
seed += seed_lsb;
|
||||
|
||||
auto csprng = ConcreteCSPRNG(seed);
|
||||
return cache ? cache->loadOrGenerateSave(params, seed_msb, seed_lsb)
|
||||
: KeySet::generate(params, std::move(csprng));
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
#ifdef CONCRETELANG_GENERATE_UNSECURE_SECRET_KEYS
|
||||
getApproval();
|
||||
#endif
|
||||
|
||||
return loadOrGenerateSave(params, seed_msb, seed_lsb);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,172 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(
|
||||
const ClientParameters &clientParameters,
|
||||
std::vector<clientlib::SharedScalarOrTensorData> &buffers)
|
||||
: clientParameters(clientParameters) {
|
||||
arguments = buffers;
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::serialize(std::ostream &ostream) {
|
||||
if (incorrectMode(ostream)) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: ostream should be in binary mode");
|
||||
}
|
||||
serializeVectorOfScalarOrTensorData(arguments, ostream);
|
||||
if (ostream.bad()) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: cannot serialize public arguments");
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
OUTCOME_TRY(arguments, unserializeVectorOfScalarOrTensorData(istream));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
PublicArguments::unserialize(const ClientParameters &expectedParams,
|
||||
std::istream &istream) {
|
||||
std::vector<SharedScalarOrTensorData> emptyBuffers;
|
||||
auto sArguments =
|
||||
std::make_unique<PublicArguments>(expectedParams, emptyBuffers);
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
return std::move(sArguments);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicResult::unserialize(std::istream &istream) {
|
||||
OUTCOME_TRY(buffers, unserializeVectorOfScalarOrTensorData(istream));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicResult::serialize(std::ostream &ostream) {
|
||||
serializeVectorOfScalarOrTensorData(buffers, ostream);
|
||||
if (ostream.bad()) {
|
||||
return StringError("PublicResult::serialize: cannot serialize");
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
void next_coord_index(size_t index[], size_t sizes[], size_t rank) {
|
||||
// increase multi dim index
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
if (index[r] < sizes[r] - 1) {
|
||||
index[r]++;
|
||||
return;
|
||||
}
|
||||
index[r] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
size_t global_index(size_t index[], size_t sizes[], size_t strides[],
|
||||
size_t rank) {
|
||||
// compute global index from multi dim index
|
||||
size_t g_index = 0;
|
||||
size_t default_stride = 1;
|
||||
for (int r = rank - 1; r >= 0; r--) {
|
||||
g_index += index[r] * ((strides[r] == 0) ? default_stride : strides[r]);
|
||||
default_stride *= sizes[r];
|
||||
}
|
||||
return g_index;
|
||||
}
|
||||
|
||||
static inline bool isReferenceToMLIRGlobalMemory(void *ptr) {
|
||||
return reinterpret_cast<uintptr_t>(ptr) == 0xdeadbeef;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TensorData tensorDataFromMemRefTyped(size_t memref_rank, void *allocatedVoid,
|
||||
void *alignedVoid, size_t offset,
|
||||
size_t *sizes, size_t *strides) {
|
||||
T *allocated = reinterpret_cast<T *>(allocatedVoid);
|
||||
T *aligned = reinterpret_cast<T *>(alignedVoid);
|
||||
|
||||
TensorData result(llvm::ArrayRef<size_t>{sizes, memref_rank}, sizeof(T) * 8,
|
||||
std::is_signed<T>());
|
||||
assert(aligned != nullptr);
|
||||
|
||||
// ephemeral multi dim index to compute global strides
|
||||
size_t *index = new size_t[memref_rank];
|
||||
for (size_t r = 0; r < memref_rank; r++) {
|
||||
index[r] = 0;
|
||||
}
|
||||
auto len = result.length();
|
||||
|
||||
// TODO: add a fast path for dense result (no real strides)
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
int g_index = offset + global_index(index, sizes, strides, memref_rank);
|
||||
result.getElementReference<T>(i) = aligned[g_index];
|
||||
next_coord_index(index, sizes, memref_rank);
|
||||
}
|
||||
delete[] index;
|
||||
// TEMPORARY: That quick and dirty but as this function is used only to
|
||||
// convert a result of the mlir program and as data are copied here, we
|
||||
// release the alocated pointer if it set.
|
||||
|
||||
if (allocated != nullptr && !isReferenceToMLIRGlobalMemory(allocated)) {
|
||||
free(allocated);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
|
||||
bool is_signed, void *allocated, void *aligned,
|
||||
size_t offset, size_t *sizes, size_t *strides) {
|
||||
ElementType et = getElementTypeFromWidthAndSign(element_width, is_signed);
|
||||
|
||||
switch (et) {
|
||||
case ElementType::i64:
|
||||
return tensorDataFromMemRefTyped<int64_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::u64:
|
||||
return tensorDataFromMemRefTyped<uint64_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::i32:
|
||||
return tensorDataFromMemRefTyped<int32_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::u32:
|
||||
return tensorDataFromMemRefTyped<uint32_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::i16:
|
||||
return tensorDataFromMemRefTyped<int16_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::u16:
|
||||
return tensorDataFromMemRefTyped<uint16_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::i8:
|
||||
return tensorDataFromMemRefTyped<int8_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
case ElementType::u8:
|
||||
return tensorDataFromMemRefTyped<uint8_t>(memref_rank, allocated, aligned,
|
||||
offset, sizes, strides);
|
||||
}
|
||||
|
||||
// Cannot happen
|
||||
assert(false);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -1,585 +0,0 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <iosfwd>
|
||||
#include <iostream>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
template <typename Key>
|
||||
std::ostream &writeUInt64KeyBuffer(std::ostream &ostream, Key &buffer) {
|
||||
writeSize(ostream, (uint64_t)buffer.size());
|
||||
ostream.write((const char *)buffer.buffer(),
|
||||
buffer.size() * sizeof(uint64_t));
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
std::shared_ptr<std::vector<uint64_t>> &vec) {
|
||||
// TODO assertion on size?
|
||||
uint64_t size;
|
||||
readSize(istream, size);
|
||||
vec->resize(size);
|
||||
istream.read((char *)vec->data(), size * sizeof(uint64_t));
|
||||
assert(istream.good());
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweSecretKey ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKeyParam param) {
|
||||
writeWord(ostream, param.dimension);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, LweSecretKeyParam ¶m) {
|
||||
readWord(istream, param.dimension);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweSecretKey /////////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweSecretKey readLweSecretKey(std::istream &istream) {
|
||||
LweSecretKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweSecretKey(buffer, param);
|
||||
}
|
||||
|
||||
// KeyswitchKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const KeyswitchKeyParam param) {
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.variance);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, KeyswitchKeyParam ¶m) {
|
||||
// TODO keys id
|
||||
param.outputSecretKeyID = 1234;
|
||||
param.inputSecretKeyID = 1234;
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.variance);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweKeyswitchKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweKeyswitchKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweKeyswitchKey readLweKeyswitchKey(std::istream &istream) {
|
||||
KeyswitchKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweKeyswitchKey(buffer, param);
|
||||
}
|
||||
|
||||
// BootstrapKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const BootstrapKeyParam param) {
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.glweDimension);
|
||||
writeWord(ostream, param.variance);
|
||||
writeWord(ostream, param.polynomialSize);
|
||||
writeWord(ostream, param.inputLweDimension);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream, BootstrapKeyParam ¶m) {
|
||||
// TODO keys id
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.glweDimension);
|
||||
readWord(istream, param.variance);
|
||||
readWord(istream, param.polynomialSize);
|
||||
readWord(istream, param.inputLweDimension);
|
||||
return istream;
|
||||
}
|
||||
|
||||
// LweBootstrapKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweBootstrapKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
LweBootstrapKey readLweBootstrapKey(std::istream &istream) {
|
||||
BootstrapKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
return LweBootstrapKey(buffer, param);
|
||||
}
|
||||
|
||||
// PackingKeyswitchKeyParam ////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKeyParam param) {
|
||||
|
||||
// TODO keys id
|
||||
writeWord(ostream, param.level);
|
||||
writeWord(ostream, param.baseLog);
|
||||
writeWord(ostream, param.glweDimension);
|
||||
writeWord(ostream, param.polynomialSize);
|
||||
writeWord(ostream, param.inputLweDimension);
|
||||
writeWord(ostream, param.variance);
|
||||
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::istream &operator>>(std::istream &istream,
|
||||
PackingKeyswitchKeyParam ¶m) {
|
||||
|
||||
// TODO keys id
|
||||
param.outputSecretKeyID = 1234;
|
||||
param.inputSecretKeyID = 1234;
|
||||
readWord(istream, param.level);
|
||||
readWord(istream, param.baseLog);
|
||||
readWord(istream, param.glweDimension);
|
||||
readWord(istream, param.polynomialSize);
|
||||
readWord(istream, param.inputLweDimension);
|
||||
readWord(istream, param.variance);
|
||||
|
||||
return istream;
|
||||
}
|
||||
|
||||
// PackingKeyswitchKey //////////////////////////////
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const PackingKeyswitchKey &key) {
|
||||
ostream << key.parameters();
|
||||
writeUInt64KeyBuffer(ostream, key);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
PackingKeyswitchKey readPackingKeyswitchKey(std::istream &istream) {
|
||||
PackingKeyswitchKeyParam param;
|
||||
istream >> param;
|
||||
auto buffer = std::make_shared<std::vector<uint64_t>>();
|
||||
istream >> buffer;
|
||||
auto b = PackingKeyswitchKey(buffer, param);
|
||||
|
||||
return b;
|
||||
}
|
||||
|
||||
// KeySet ////////////////////////////////
|
||||
|
||||
std::unique_ptr<KeySet> readKeySet(std::istream &istream) {
|
||||
uint64_t nbKey;
|
||||
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweSecretKey> secretKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
secretKeys.push_back(readLweSecretKey(istream));
|
||||
}
|
||||
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
bootstrapKeys.push_back(readLweBootstrapKey(istream));
|
||||
}
|
||||
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
|
||||
}
|
||||
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
readSize(istream, nbKey);
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
|
||||
}
|
||||
|
||||
std::string clientParametersString;
|
||||
istream >> clientParametersString;
|
||||
auto clientParameters =
|
||||
llvm::json::parse<ClientParameters>(clientParametersString);
|
||||
|
||||
if (!clientParameters) {
|
||||
return std::unique_ptr<KeySet>(nullptr);
|
||||
}
|
||||
|
||||
auto csprng = ConcreteCSPRNG(0);
|
||||
auto keySet =
|
||||
KeySet::fromKeys(clientParameters.get(), secretKeys, bootstrapKeys,
|
||||
keyswitchKeys, packingKeyswitchKeys, std::move(csprng));
|
||||
|
||||
return std::move(keySet.value());
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const KeySet &keySet) {
|
||||
auto secretKeys = keySet.getSecretKeys();
|
||||
writeSize(ostream, secretKeys.size());
|
||||
for (auto sk : secretKeys) {
|
||||
ostream << sk;
|
||||
}
|
||||
|
||||
auto bootstrapKeys = keySet.getBootstrapKeys();
|
||||
writeSize(ostream, bootstrapKeys.size());
|
||||
for (auto bsk : bootstrapKeys) {
|
||||
ostream << bsk;
|
||||
}
|
||||
|
||||
auto keyswitchKeys = keySet.getKeyswitchKeys();
|
||||
writeSize(ostream, keyswitchKeys.size());
|
||||
for (auto ksk : keyswitchKeys) {
|
||||
ostream << ksk;
|
||||
}
|
||||
|
||||
auto packingKeyswitchKeys = keySet.getPackingKeyswitchKeys();
|
||||
writeSize(ostream, packingKeyswitchKeys.size());
|
||||
for (auto pksk : packingKeyswitchKeys) {
|
||||
ostream << pksk;
|
||||
}
|
||||
|
||||
auto clientParametersJson = llvm::json::Value(keySet.clientParameters());
|
||||
std::string clientParametersString;
|
||||
llvm::raw_string_ostream clientParametersStringBuffer(clientParametersString);
|
||||
clientParametersStringBuffer << clientParametersJson;
|
||||
ostream << clientParametersString;
|
||||
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
// EvaluationKey ////////////////////////////////
|
||||
|
||||
EvaluationKeys readEvaluationKeys(std::istream &istream) {
|
||||
uint64_t nbKey;
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweBootstrapKey> bootstrapKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
bootstrapKeys.push_back(readLweBootstrapKey(istream));
|
||||
}
|
||||
readSize(istream, nbKey);
|
||||
std::vector<LweKeyswitchKey> keyswitchKeys;
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
keyswitchKeys.push_back(readLweKeyswitchKey(istream));
|
||||
}
|
||||
std::vector<PackingKeyswitchKey> packingKeyswitchKeys;
|
||||
readSize(istream, nbKey);
|
||||
for (uint64_t i = 0; i < nbKey; i++) {
|
||||
packingKeyswitchKeys.push_back(readPackingKeyswitchKey(istream));
|
||||
}
|
||||
return EvaluationKeys(keyswitchKeys, bootstrapKeys, packingKeyswitchKeys);
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
const EvaluationKeys &evaluationKeys) {
|
||||
auto bootstrapKeys = evaluationKeys.getBootstrapKeys();
|
||||
writeSize(ostream, bootstrapKeys.size());
|
||||
for (auto bsk : bootstrapKeys) {
|
||||
ostream << bsk;
|
||||
}
|
||||
auto keyswitchKeys = evaluationKeys.getKeyswitchKeys();
|
||||
writeSize(ostream, keyswitchKeys.size());
|
||||
for (auto ksk : keyswitchKeys) {
|
||||
ostream << ksk;
|
||||
}
|
||||
auto packingKeyswitchKeys = evaluationKeys.getPackingKeyswitchKeys();
|
||||
writeSize(ostream, packingKeyswitchKeys.size());
|
||||
for (auto pksk : packingKeyswitchKeys) {
|
||||
ostream << pksk;
|
||||
}
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
}
|
||||
|
||||
// TensorData ///////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
std::ostream &serializeScalarDataRaw(T value, std::ostream &ostream) {
|
||||
writeWord<uint64_t>(ostream, sizeof(T) * 8);
|
||||
writeWord<uint8_t>(ostream, std::is_signed<T>());
|
||||
writeWord<T>(ostream, value);
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &serializeScalarData(const ScalarData &sd, std::ostream &ostream) {
|
||||
switch (sd.getType()) {
|
||||
case ElementType::u64:
|
||||
return serializeScalarDataRaw<uint64_t>(sd.getValue<uint64_t>(), ostream);
|
||||
case ElementType::i64:
|
||||
return serializeScalarDataRaw<int64_t>(sd.getValue<int64_t>(), ostream);
|
||||
case ElementType::u32:
|
||||
return serializeScalarDataRaw<uint32_t>(sd.getValue<uint32_t>(), ostream);
|
||||
case ElementType::i32:
|
||||
return serializeScalarDataRaw<int32_t>(sd.getValue<int32_t>(), ostream);
|
||||
case ElementType::u16:
|
||||
return serializeScalarDataRaw<uint16_t>(sd.getValue<uint16_t>(), ostream);
|
||||
case ElementType::i16:
|
||||
return serializeScalarDataRaw<int16_t>(sd.getValue<int16_t>(), ostream);
|
||||
case ElementType::u8:
|
||||
return serializeScalarDataRaw<uint8_t>(sd.getValue<uint8_t>(), ostream);
|
||||
case ElementType::i8:
|
||||
return serializeScalarDataRaw<int8_t>(sd.getValue<int8_t>(), ostream);
|
||||
}
|
||||
|
||||
return ostream;
|
||||
}
|
||||
|
||||
template <typename T> ScalarData unserializeScalarValue(std::istream &istream) {
|
||||
T value;
|
||||
readWord(istream, value);
|
||||
return ScalarData(value);
|
||||
}
|
||||
|
||||
outcome::checked<ScalarData, StringError>
|
||||
unserializeScalarData(std::istream &istream) {
|
||||
uint64_t scalarWidth;
|
||||
readWord(istream, scalarWidth);
|
||||
|
||||
switch (scalarWidth) {
|
||||
case 64:
|
||||
case 32:
|
||||
case 16:
|
||||
case 8:
|
||||
break;
|
||||
default:
|
||||
return StringError("Scalar width must be either 64, 32, 16 or 8, but got ")
|
||||
<< scalarWidth;
|
||||
}
|
||||
|
||||
uint8_t scalarSignedness;
|
||||
readWord(istream, scalarSignedness);
|
||||
|
||||
if (scalarSignedness != 0 && scalarSignedness != 1) {
|
||||
return StringError("Numerical value for scalar signedness must be either "
|
||||
"0 or 1, but got ")
|
||||
<< scalarSignedness;
|
||||
}
|
||||
|
||||
switch (scalarWidth) {
|
||||
case 64:
|
||||
return (scalarSignedness) ? unserializeScalarValue<int64_t>(istream)
|
||||
: unserializeScalarValue<uint64_t>(istream);
|
||||
case 32:
|
||||
return (scalarSignedness) ? unserializeScalarValue<int32_t>(istream)
|
||||
: unserializeScalarValue<uint32_t>(istream);
|
||||
case 16:
|
||||
return (scalarSignedness) ? unserializeScalarValue<int16_t>(istream)
|
||||
: unserializeScalarValue<uint16_t>(istream);
|
||||
case 8:
|
||||
return (scalarSignedness) ? unserializeScalarValue<int8_t>(istream)
|
||||
: unserializeScalarValue<uint8_t>(istream);
|
||||
}
|
||||
|
||||
assert(false && "Unhandled scalar type");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
|
||||
std::istream &istream) {
|
||||
// getElementPointer is not valid if the tensor contains no data
|
||||
if (values_and_sizes.getNumElements() > 0) {
|
||||
readWords(istream, values_and_sizes.getElementPointer<T>(0),
|
||||
values_and_sizes.getNumElements());
|
||||
}
|
||||
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
|
||||
std::ostream &ostream) {
|
||||
switch (values_and_sizes.getElementType()) {
|
||||
case ElementType::u64:
|
||||
return serializeTensorDataRaw<uint64_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint64_t>(), ostream);
|
||||
case ElementType::i64:
|
||||
return serializeTensorDataRaw<int64_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int64_t>(), ostream);
|
||||
case ElementType::u32:
|
||||
return serializeTensorDataRaw<uint32_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint32_t>(), ostream);
|
||||
case ElementType::i32:
|
||||
return serializeTensorDataRaw<int32_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int32_t>(), ostream);
|
||||
case ElementType::u16:
|
||||
return serializeTensorDataRaw<uint16_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint16_t>(), ostream);
|
||||
case ElementType::i16:
|
||||
return serializeTensorDataRaw<int16_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int16_t>(), ostream);
|
||||
case ElementType::u8:
|
||||
return serializeTensorDataRaw<uint8_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint8_t>(), ostream);
|
||||
case ElementType::i8:
|
||||
return serializeTensorDataRaw<int8_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int8_t>(), ostream);
|
||||
}
|
||||
|
||||
assert(false && "Unhandled element type");
|
||||
}
|
||||
|
||||
outcome::checked<TensorData, StringError>
|
||||
unserializeTensorData(std::istream &istream) {
|
||||
|
||||
if (incorrectMode(istream)) {
|
||||
return StringError("Stream is in incorrect mode");
|
||||
}
|
||||
|
||||
uint64_t numDimensions;
|
||||
readWord(istream, numDimensions);
|
||||
|
||||
std::vector<size_t> dims;
|
||||
|
||||
for (uint64_t i = 0; i < numDimensions; i++) {
|
||||
int64_t dimSize;
|
||||
readWord(istream, dimSize);
|
||||
dims.push_back(dimSize);
|
||||
}
|
||||
|
||||
uint64_t elementWidth;
|
||||
readWord(istream, elementWidth);
|
||||
|
||||
switch (elementWidth) {
|
||||
case 64:
|
||||
case 32:
|
||||
case 16:
|
||||
case 8:
|
||||
break;
|
||||
default:
|
||||
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
|
||||
<< elementWidth;
|
||||
}
|
||||
|
||||
uint8_t elementSignedness;
|
||||
readWord(istream, elementSignedness);
|
||||
|
||||
if (elementSignedness != 0 && elementSignedness != 1) {
|
||||
return StringError("Numerical value for element signedness must be either "
|
||||
"0 or 1, but got ")
|
||||
<< elementSignedness;
|
||||
}
|
||||
|
||||
TensorData result(dims, elementWidth, elementSignedness == 1);
|
||||
|
||||
switch (result.getElementType()) {
|
||||
case ElementType::u64:
|
||||
unserializeTensorDataElements<uint64_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
unserializeTensorDataElements<int64_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
unserializeTensorDataElements<uint32_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
unserializeTensorDataElements<int32_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
unserializeTensorDataElements<uint16_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
unserializeTensorDataElements<int16_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
unserializeTensorDataElements<uint8_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
unserializeTensorDataElements<int8_t>(result, istream);
|
||||
break;
|
||||
}
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
std::ostream &ostream) {
|
||||
writeWord<uint8_t>(ostream, sotd.isTensor());
|
||||
|
||||
if (sotd.isTensor())
|
||||
return serializeTensorData(sotd.getTensor(), ostream);
|
||||
else
|
||||
return serializeScalarData(sotd.getScalar(), ostream);
|
||||
}
|
||||
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(std::istream &istream) {
|
||||
uint8_t isTensor;
|
||||
readWord(istream, isTensor);
|
||||
|
||||
if (isTensor != 0 && isTensor != 1) {
|
||||
return StringError("Numerical value indicating whether a data element is a "
|
||||
"tensor must be either 0 or 1, but got ")
|
||||
<< isTensor;
|
||||
}
|
||||
|
||||
if (isTensor) {
|
||||
auto tdOrErr = unserializeTensorData(istream);
|
||||
|
||||
if (tdOrErr.has_error())
|
||||
return std::move(tdOrErr.error());
|
||||
else
|
||||
return ScalarOrTensorData(std::move(tdOrErr.value()));
|
||||
} else {
|
||||
auto tdOrErr = unserializeScalarData(istream);
|
||||
|
||||
if (tdOrErr.has_error())
|
||||
return std::move(tdOrErr.error());
|
||||
else
|
||||
return ScalarOrTensorData(std::move(tdOrErr.value()));
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &serializeVectorOfScalarOrTensorData(
|
||||
const std::vector<SharedScalarOrTensorData> &v, std::ostream &ostream) {
|
||||
writeSize(ostream, v.size());
|
||||
for (auto &sotd : v) {
|
||||
serializeScalarOrTensorData(sotd.get(), ostream);
|
||||
if (!ostream.good()) {
|
||||
return ostream;
|
||||
}
|
||||
}
|
||||
return ostream;
|
||||
}
|
||||
outcome::checked<std::vector<SharedScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
|
||||
uint64_t nbElt;
|
||||
readSize(istream, nbElt);
|
||||
std::vector<SharedScalarOrTensorData> v;
|
||||
for (uint64_t i = 0; i < nbElt; i++) {
|
||||
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
|
||||
v.push_back(SharedScalarOrTensorData(std::move(elt)));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -0,0 +1,20 @@
|
||||
add_compile_options(-fexceptions -fsized-deallocation -fno-rtti)
|
||||
|
||||
add_mlir_library(
|
||||
ConcretelangCommon
|
||||
Protocol.cpp
|
||||
CRT.cpp
|
||||
Csprng.cpp
|
||||
Keys.cpp
|
||||
Keysets.cpp
|
||||
Transformers.cpp
|
||||
Values.cpp
|
||||
DEPENDS
|
||||
concrete-protocol
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
concrete_cpu
|
||||
kj
|
||||
capnp)
|
||||
|
||||
target_include_directories(ConcretelangCommon PUBLIC ${CONCRETE_CPU_INCLUDE_DIR})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user