From 7d785eebec38f2ac68fdbcb0c7d9e6845bc394cd Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 2 Dec 2022 10:01:26 +0100 Subject: [PATCH] feat(rust): support serialization --- .../concretelang-c/Support/CompilerEngine.h | 46 ++- compiler/include/concretelang/CAPI/Wrappers.h | 2 +- compiler/lib/Bindings/Rust/src/compiler.rs | 266 +++++++++++++++++- compiler/lib/CAPI/Support/CompilerEngine.cpp | 121 +++++++- 4 files changed, 415 insertions(+), 20 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 152260253..cf42ea795 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -24,7 +24,7 @@ extern "C" { #define DEFINE_C_API_STRUCT(name, storage) \ struct name { \ storage *ptr; \ - char *error; \ + const char *error; \ }; \ typedef struct name name @@ -86,6 +86,31 @@ DEFINE_NULL_PTR_CHECKER(compilationFeedbackIsNull, CompilationFeedback); /// allocated memory for and know how to free. MLIR_CAPI_EXPORTED void mlirStringRefDestroy(MlirStringRef str); +MLIR_CAPI_EXPORTED bool mlirStringRefIsNull(MlirStringRef str) { + return str.data == NULL; +} + +/// ********** BufferRef CAPI ************************************************** + +/// A struct for binary buffers. +/// +/// Contraty to MlirStringRef, it doesn't assume the pointer point to a null +/// terminated string and the data should be considered as is in binary form. +/// Useful for serialized objects. +typedef struct BufferRef { + const char *data; + size_t length; + const char *error; +} BufferRef; + +MLIR_CAPI_EXPORTED void bufferRefDestroy(BufferRef buffer); + +MLIR_CAPI_EXPORTED bool bufferRefIsNull(BufferRef buffer) { + return buffer.data == NULL; +} + +MLIR_CAPI_EXPORTED BufferRef bufferRefCreate(const char *buffer, size_t length); + /// ********** CompilationTarget CAPI ****************************************** enum CompilationTarget { @@ -195,6 +220,11 @@ MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server); /// ********** ClientParameters CAPI ******************************************* +MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params); + +MLIR_CAPI_EXPORTED ClientParameters +clientParametersUnserialize(BufferRef buffer); + MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params); /// ********** KeySet CAPI ***************************************************** @@ -218,6 +248,10 @@ MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache); /// ********** EvaluationKeys CAPI ********************************************* +MLIR_CAPI_EXPORTED BufferRef evaluationKeysSerialize(EvaluationKeys keys); + +MLIR_CAPI_EXPORTED EvaluationKeys evaluationKeysUnserialize(BufferRef buffer); + MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys); /// ********** LambdaArgument CAPI ********************************************* @@ -257,6 +291,11 @@ MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg); /// ********** PublicArguments CAPI ******************************************** +MLIR_CAPI_EXPORTED BufferRef publicArgumentsSerialize(PublicArguments args); + +MLIR_CAPI_EXPORTED PublicArguments +publicArgumentsUnserialize(BufferRef buffer, ClientParameters params); + MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs); /// ********** PublicResult CAPI *********************************************** @@ -264,6 +303,11 @@ MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs); MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet); +MLIR_CAPI_EXPORTED BufferRef publicResultSerialize(PublicResult result); + +MLIR_CAPI_EXPORTED PublicResult +publicResultUnserialize(BufferRef buffer, ClientParameters params); + MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult); /// ********** CompilationFeedback CAPI **************************************** diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h index 29a7e329a..2506209d9 100644 --- a/compiler/include/concretelang/CAPI/Wrappers.h +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -22,7 +22,7 @@ static inline cpptype *unwrap(name c) { \ return static_cast(c.ptr); \ } \ - static inline char *getErrorPtr(name c) { return c.error; } + static inline const char *getErrorPtr(name c) { return c.error; } DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilerEngine, mlir::concretelang::CompilerEngine) diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 6e2dd5eb6..e49bb87c3 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -1,22 +1,22 @@ //! Compiler module -use std::{ffi::CStr, path::Path}; - use crate::mlir::ffi::*; +use std::os::raw::c_char; +use std::{ffi::CStr, path::Path}; #[derive(Debug)] pub struct CompilerError(String); /// Retreive buffer of the error message from a C struct. trait CStructErrorMsg { - fn error_msg(&self) -> *mut i8; + fn error_msg(&self) -> *const i8; } /// All C struct can return a pointer to the allocated error message. macro_rules! impl_CStructErrorMsg { ([$($t:ty),+]) => { $(impl CStructErrorMsg for $t { - fn error_msg(&self) -> *mut i8 { + fn error_msg(&self) -> *const i8 { self.error } })* @@ -32,7 +32,9 @@ impl_CStructErrorMsg! {[ PublicResult, KeySet, KeySetCache, - LambdaArgument + LambdaArgument, + BufferRef, + EvaluationKeys ]} /// Construct a rust error message from a buffer in the C struct. @@ -58,6 +60,19 @@ unsafe fn mlir_string_ref_to_string(str_ref: MlirStringRef) -> String { result } +/// Create a vector of bytes from a BufferRef and free its memory. +/// +/// # SAFETY +/// +/// This should only be used with string refs returned by the compiler. +unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec { + let result = + std::slice::from_raw_parts(buffer_ref.data as *const c_char, buffer_ref.length as usize) + .to_vec(); + bufferRefDestroy(buffer_ref); + result +} + /// Parse the MLIR code and returns it. /// /// The function parse the provided MLIR textual representation and returns it. It would fail with @@ -82,7 +97,7 @@ pub fn round_trip(mlir_code: &str) -> Result { let compilation_result = compilerEngineCompile( engine, MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char, + data: mlir_code_buffer.as_ptr() as *const c_char, length: mlir_code_buffer.len() as size_t, }, CompilationTarget_ROUND_TRIP, @@ -128,11 +143,11 @@ impl LibrarySupport { let runtime_library_path_buffer = runtime_library_path.as_bytes(); let support = librarySupportCreateDefault( MlirStringRef { - data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char, + data: output_dir_path_buffer.as_ptr() as *const c_char, length: output_dir_path_buffer.len() as size_t, }, MlirStringRef { - data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char, + data: runtime_library_path_buffer.as_ptr() as *const c_char, length: runtime_library_path_buffer.len() as size_t, }, ); @@ -160,7 +175,7 @@ impl LibrarySupport { let result = librarySupportCompile( self.support, MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char, + data: mlir_code_buffer.as_ptr() as *const c_char, length: mlir_code_buffer.len() as size_t, }, options, @@ -281,7 +296,7 @@ impl ClientSupport { Some(path) => { let cache_path_buffer = path.to_str().unwrap().as_bytes(); let cache = keySetCacheCreate(MlirStringRef { - data: cache_path_buffer.as_ptr() as *const std::os::raw::c_char, + data: cache_path_buffer.as_ptr() as *const c_char, length: cache_path_buffer.len() as size_t, }); if keySetCacheIsNull(cache) { @@ -383,6 +398,174 @@ impl ClientSupport { } } +// TODO: implement traits for C Struct that are serializable and reduce code for serialization and maybe refactor other functions. +// destroy and is_null could be implemented for other struct as well. +// +// trait Serializable { +// fn into_buffer_ref(self) -> BufferRef; +// fn from_buffer_ref(buff: BufferRef, params: Option) -> Self; +// fn is_null(self) -> bool; +// fn destroy(self); +// } + +// fn serialize(to_serialize: T) -> Result, CompilerError> { +// unsafe { +// let serialized_ref = to_serialize.into_buffer_ref(); +// if bufferRefIsNull(serialized_ref) { +// let error_msg = get_error_msg_from_ctype(serialized_ref); +// bufferRefDestroy(serialized_ref); +// return Err(CompilerError(error_msg)); +// } +// let serialized = buffer_ref_to_bytes(serialized_ref); +// Ok(serialized) +// } +// } + +// fn unserialize( +// serialized: &Vec, +// client_parameters: Option, +// ) -> Result { +// unsafe { +// let serialized_ref = bufferRefCreate( +// serialized.as_ptr() as *const c_char, +// serialized.len().try_into().unwrap(), +// ); +// let serialized = T::from_buffer_ref(serialized_ref, client_parameters); +// if serialized.is_null() { +// let error_msg = get_error_msg_from_ctype(serialized); +// serialized.destroy(); +// return Err(CompilerError(error_msg)); +// } +// Ok(serialized) +// } +// } + +impl PublicArguments { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = publicArgumentsSerialize(self); + if bufferRefIsNull(serialized_ref) { + let error_msg = get_error_msg_from_ctype(serialized_ref); + bufferRefDestroy(serialized_ref); + return Err(CompilerError(error_msg)); + } + let serialized = buffer_ref_to_bytes(serialized_ref); + Ok(serialized) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = bufferRefCreate( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ); + let public_args = publicArgumentsUnserialize(serialized_ref, client_parameters); + if publicArgumentsIsNull(public_args) { + let error_msg = get_error_msg_from_ctype(public_args); + publicArgumentsDestroy(public_args); + return Err(CompilerError(error_msg)); + } + Ok(public_args) + } + } +} + +impl PublicResult { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = publicResultSerialize(self); + if bufferRefIsNull(serialized_ref) { + let error_msg = get_error_msg_from_ctype(serialized_ref); + bufferRefDestroy(serialized_ref); + return Err(CompilerError(error_msg)); + } + let serialized = buffer_ref_to_bytes(serialized_ref); + Ok(serialized) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = bufferRefCreate( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ); + let public_result = publicResultUnserialize(serialized_ref, client_parameters); + if publicResultIsNull(public_result) { + let error_msg = get_error_msg_from_ctype(public_result); + publicResultDestroy(public_result); + return Err(CompilerError(error_msg)); + } + Ok(public_result) + } + } +} + +impl EvaluationKeys { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = evaluationKeysSerialize(self); + if bufferRefIsNull(serialized_ref) { + let error_msg = get_error_msg_from_ctype(serialized_ref); + bufferRefDestroy(serialized_ref); + return Err(CompilerError(error_msg)); + } + let serialized = buffer_ref_to_bytes(serialized_ref); + Ok(serialized) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = bufferRefCreate( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ); + let eval_keys = evaluationKeysUnserialize(serialized_ref); + if evaluationKeysIsNull(eval_keys) { + let error_msg = get_error_msg_from_ctype(eval_keys); + evaluationKeysDestroy(eval_keys); + return Err(CompilerError(error_msg)); + } + Ok(eval_keys) + } + } +} + +impl ClientParameters { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = clientParametersSerialize(self); + if bufferRefIsNull(serialized_ref) { + let error_msg = get_error_msg_from_ctype(serialized_ref); + bufferRefDestroy(serialized_ref); + return Err(CompilerError(error_msg)); + } + let serialized = buffer_ref_to_bytes(serialized_ref); + Ok(serialized) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = bufferRefCreate( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ); + let params = clientParametersUnserialize(serialized_ref); + if clientParametersIsNull(params) { + let error_msg = get_error_msg_from_ctype(params); + clientParametersDestroy(params); + return Err(CompilerError(error_msg)); + } + Ok(params) + } + } +} + #[cfg(test)] mod test { use std::env; @@ -531,6 +714,69 @@ mod test { } } + #[test] + fn test_compiler_compile_and_exec_with_serialization() { + unsafe { + let module_to_compile = " + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + }"; + let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { + Ok(val) => val + "/lib/libConcretelangRuntime.so", + Err(_e) => "".to_string(), + }; + let temp_dir = + TempDir::new("rust_test_compiler_compile_and_exec_with_serialization").unwrap(); + let lib_support = LibrarySupport::new( + temp_dir.path().to_str().unwrap(), + runtime_library_path.as_str(), + ) + .unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(result).unwrap(); + let client_params = lib_support.load_client_parameters(result).unwrap(); + // serialize client parameters + let serialized_params = client_params.serialize().unwrap(); + let client_params = ClientParameters::unserialize(&serialized_params).unwrap(); + // create client support + let client_support = ClientSupport::new(client_params, None).unwrap(); + let key_set = client_support.keyset(None, None).unwrap(); + let eval_keys = keySetGetEvaluationKeys(key_set); + // serialize eval keys + let serialized_eval_keys = eval_keys.serialize().unwrap(); + let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; + let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); + // free args + args.map(|arg| lambdaArgumentDestroy(arg)); + // serialize args + let serialized_encrypted_args = encrypted_args.serialize().unwrap(); + let encrypted_args = + PublicArguments::unserialize(&serialized_encrypted_args, client_params).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(server_lambda, encrypted_args, eval_keys) + .unwrap(); + // serialize result + let serialized_encrypted_result = encrypted_result.serialize().unwrap(); + let encrypted_result = + PublicResult::unserialize(&serialized_encrypted_result, client_params).unwrap(); + // decrypt the result of execution + let result_arg = client_support + .decrypt_result(encrypted_result, key_set) + .unwrap(); + // get the scalar value from the result lambda argument + let result = lambdaArgumentGetScalar(result_arg); + assert_eq!(result, 6); + } + } + #[test] fn test_tensor_lambda_argument() { unsafe { diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 26250221f..5e9fea2ec 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -17,14 +17,48 @@ auto *cpp = unwrap(c_struct); \ if (cpp != NULL) \ delete cpp; \ - char *error = getErrorPtr(c_struct); \ + const char *error = getErrorPtr(c_struct); \ if (error != NULL) \ delete[] error; +/// ********** BufferRef CAPI ************************************************** + +BufferRef bufferRefCreate(const char *buffer, size_t length) { + return BufferRef{buffer, length, NULL}; +} + +BufferRef bufferRefFromString(std::string str) { + char *buffer = new char[str.size()]; + memcpy(buffer, str.c_str(), str.size()); + return bufferRefCreate(buffer, str.size()); +} + +BufferRef bufferRefFromStringError(std::string error) { + char *buffer = new char[error.size()]; + memcpy(buffer, error.c_str(), error.size()); + return BufferRef{NULL, 0, buffer}; +} + +void bufferRefDestroy(BufferRef buffer) { + if (buffer.data != NULL) + delete[] buffer.data; + if (buffer.error != NULL) + delete[] buffer.error; +} + /// ********** Utilities ******************************************************* void mlirStringRefDestroy(MlirStringRef str) { delete[] str.data; } +template BufferRef serialize(T toSerialize) { + std::ostringstream ostream(std::ios::binary); + auto voidOrError = unwrap(toSerialize)->serialize(ostream); + if (voidOrError.has_error()) { + return bufferRefFromStringError(voidOrError.error().mesg); + } + return bufferRefFromString(ostream.str()); +} + /// ********** CompilationOptions CAPI ***************************************** CompilationOptions @@ -273,10 +307,31 @@ void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) } /// ********** ServerLamda CAPI ************************************************ -void serverLambdaDestroy(ServerLambda server) { C_STRUCT_CLEANER(server) } +void serverLambdaDestroy(ServerLambda server){C_STRUCT_CLEANER(server)} /// ********** ClientParameters CAPI ******************************************* +BufferRef clientParametersSerialize(ClientParameters params) { + llvm::json::Value value(*unwrap(params)); + std::string jsonParams; + llvm::raw_string_ostream ostream(jsonParams); + ostream << value; + char *buffer = new char[jsonParams.size() + 1]; + strcpy(buffer, jsonParams.c_str()); + return bufferRefCreate(buffer, jsonParams.size()); +} + +ClientParameters clientParametersUnserialize(BufferRef buffer) { + std::string json(buffer.data, buffer.length); + auto paramsOrError = + llvm::json::parse(json); + if (!paramsOrError) { + return wrap((mlir::concretelang::ClientParameters *)NULL, + llvm::toString(paramsOrError.takeError())); + } + return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get())); +} + void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)} /// ********** KeySet CAPI ***************************************************** @@ -318,12 +373,31 @@ KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache, return wrap(keySetOrError.value().release()); } -void keySetCacheDestroy(KeySetCache keySetCache) { - C_STRUCT_CLEANER(keySetCache) -} +void keySetCacheDestroy(KeySetCache keySetCache){C_STRUCT_CLEANER(keySetCache)} /// ********** EvaluationKeys CAPI ********************************************* +BufferRef evaluationKeysSerialize(EvaluationKeys keys) { + std::ostringstream ostream(std::ios::binary); + concretelang::clientlib::operator<<(ostream, *unwrap(keys)); + if (ostream.fail()) { + return bufferRefFromStringError( + "output stream failure during evaluation keys serialization"); + } + return bufferRefFromString(ostream.str()); +} + +EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) { + std::stringstream istream(std::string(buffer.data, buffer.length)); + concretelang::clientlib::EvaluationKeys evaluationKeys; + concretelang::clientlib::operator>>(istream, evaluationKeys); + if (istream.fail()) { + return wrap((concretelang::clientlib::EvaluationKeys *)NULL, + "input stream failure during evaluation keys unserialization"); + } + return wrap(new concretelang::clientlib::EvaluationKeys(evaluationKeys)); +} + void evaluationKeysDestroy(EvaluationKeys evaluationKeys) { C_STRUCT_CLEANER(evaluationKeys); } @@ -540,12 +614,27 @@ PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, return wrap(publicArgsOrError.get().release()); } -void lambdaArgumentDestroy(LambdaArgument lambdaArg) { - C_STRUCT_CLEANER(lambdaArg) -} +void lambdaArgumentDestroy(LambdaArgument lambdaArg){ + C_STRUCT_CLEANER(lambdaArg)} /// ********** PublicArguments CAPI ******************************************** +BufferRef publicArgumentsSerialize(PublicArguments args) { + return serialize(args); +} + +PublicArguments publicArgumentsUnserialize(BufferRef buffer, + ClientParameters params) { + std::stringstream istream(std::string(buffer.data, buffer.length)); + auto argsOrError = concretelang::clientlib::PublicArguments::unserialize( + *unwrap(params), istream); + if (!argsOrError) { + return wrap((concretelang::clientlib::PublicArguments *)NULL, + argsOrError.error().mesg); + } + return wrap(argsOrError.value().release()); +} + void publicArgumentsDestroy(PublicArguments publicArgs){ C_STRUCT_CLEANER(publicArgs)} @@ -563,6 +652,22 @@ LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) { return wrap(lambdaArgOrError.get().release()); } +BufferRef publicResultSerialize(PublicResult result) { + return serialize(result); +} + +PublicResult publicResultUnserialize(BufferRef buffer, + ClientParameters params) { + std::stringstream istream(std::string(buffer.data, buffer.length)); + auto resultOrError = concretelang::clientlib::PublicResult::unserialize( + *unwrap(params), istream); + if (!resultOrError) { + return wrap((concretelang::clientlib::PublicResult *)NULL, + resultOrError.error().mesg); + } + return wrap(resultOrError.value().release()); +} + void publicResultDestroy(PublicResult publicResult) { C_STRUCT_CLEANER(publicResult) }