feat(rust): support serialization

This commit is contained in:
youben11
2022-12-02 10:01:26 +01:00
committed by Ayoub Benaissa
parent fbc60097ab
commit 7d785eebec
4 changed files with 415 additions and 20 deletions

View File

@@ -24,7 +24,7 @@ extern "C" {
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
char *error; \
const char *error; \
}; \
typedef struct name name
@@ -86,6 +86,31 @@ DEFINE_NULL_PTR_CHECKER(compilationFeedbackIsNull, CompilationFeedback);
/// allocated memory for and know how to free.
MLIR_CAPI_EXPORTED void mlirStringRefDestroy(MlirStringRef str);
MLIR_CAPI_EXPORTED bool mlirStringRefIsNull(MlirStringRef str) {
return str.data == NULL;
}
/// ********** BufferRef CAPI **************************************************
/// A struct for binary buffers.
///
/// Contraty to MlirStringRef, it doesn't assume the pointer point to a null
/// terminated string and the data should be considered as is in binary form.
/// Useful for serialized objects.
typedef struct BufferRef {
const char *data;
size_t length;
const char *error;
} BufferRef;
MLIR_CAPI_EXPORTED void bufferRefDestroy(BufferRef buffer);
MLIR_CAPI_EXPORTED bool bufferRefIsNull(BufferRef buffer) {
return buffer.data == NULL;
}
MLIR_CAPI_EXPORTED BufferRef bufferRefCreate(const char *buffer, size_t length);
/// ********** CompilationTarget CAPI ******************************************
enum CompilationTarget {
@@ -195,6 +220,11 @@ MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
/// ********** ClientParameters CAPI *******************************************
MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params);
MLIR_CAPI_EXPORTED ClientParameters
clientParametersUnserialize(BufferRef buffer);
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
/// ********** KeySet CAPI *****************************************************
@@ -218,6 +248,10 @@ MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
/// ********** EvaluationKeys CAPI *********************************************
MLIR_CAPI_EXPORTED BufferRef evaluationKeysSerialize(EvaluationKeys keys);
MLIR_CAPI_EXPORTED EvaluationKeys evaluationKeysUnserialize(BufferRef buffer);
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
/// ********** LambdaArgument CAPI *********************************************
@@ -257,6 +291,11 @@ MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
/// ********** PublicArguments CAPI ********************************************
MLIR_CAPI_EXPORTED BufferRef publicArgumentsSerialize(PublicArguments args);
MLIR_CAPI_EXPORTED PublicArguments
publicArgumentsUnserialize(BufferRef buffer, ClientParameters params);
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
/// ********** PublicResult CAPI ***********************************************
@@ -264,6 +303,11 @@ MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
KeySet keySet);
MLIR_CAPI_EXPORTED BufferRef publicResultSerialize(PublicResult result);
MLIR_CAPI_EXPORTED PublicResult
publicResultUnserialize(BufferRef buffer, ClientParameters params);
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
/// ********** CompilationFeedback CAPI ****************************************

View File

@@ -22,7 +22,7 @@
static inline cpptype *unwrap(name c) { \
return static_cast<cpptype *>(c.ptr); \
} \
static inline char *getErrorPtr(name c) { return c.error; }
static inline const char *getErrorPtr(name c) { return c.error; }
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilerEngine,
mlir::concretelang::CompilerEngine)

View File

@@ -1,22 +1,22 @@
//! Compiler module
use std::{ffi::CStr, path::Path};
use crate::mlir::ffi::*;
use std::os::raw::c_char;
use std::{ffi::CStr, path::Path};
#[derive(Debug)]
pub struct CompilerError(String);
/// Retreive buffer of the error message from a C struct.
trait CStructErrorMsg {
fn error_msg(&self) -> *mut i8;
fn error_msg(&self) -> *const i8;
}
/// All C struct can return a pointer to the allocated error message.
macro_rules! impl_CStructErrorMsg {
([$($t:ty),+]) => {
$(impl CStructErrorMsg for $t {
fn error_msg(&self) -> *mut i8 {
fn error_msg(&self) -> *const i8 {
self.error
}
})*
@@ -32,7 +32,9 @@ impl_CStructErrorMsg! {[
PublicResult,
KeySet,
KeySetCache,
LambdaArgument
LambdaArgument,
BufferRef,
EvaluationKeys
]}
/// Construct a rust error message from a buffer in the C struct.
@@ -58,6 +60,19 @@ unsafe fn mlir_string_ref_to_string(str_ref: MlirStringRef) -> String {
result
}
/// Create a vector of bytes from a BufferRef and free its memory.
///
/// # SAFETY
///
/// This should only be used with string refs returned by the compiler.
unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec<c_char> {
let result =
std::slice::from_raw_parts(buffer_ref.data as *const c_char, buffer_ref.length as usize)
.to_vec();
bufferRefDestroy(buffer_ref);
result
}
/// Parse the MLIR code and returns it.
///
/// The function parse the provided MLIR textual representation and returns it. It would fail with
@@ -82,7 +97,7 @@ pub fn round_trip(mlir_code: &str) -> Result<String, CompilerError> {
let compilation_result = compilerEngineCompile(
engine,
MlirStringRef {
data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char,
data: mlir_code_buffer.as_ptr() as *const c_char,
length: mlir_code_buffer.len() as size_t,
},
CompilationTarget_ROUND_TRIP,
@@ -128,11 +143,11 @@ impl LibrarySupport {
let runtime_library_path_buffer = runtime_library_path.as_bytes();
let support = librarySupportCreateDefault(
MlirStringRef {
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
data: output_dir_path_buffer.as_ptr() as *const c_char,
length: output_dir_path_buffer.len() as size_t,
},
MlirStringRef {
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
data: runtime_library_path_buffer.as_ptr() as *const c_char,
length: runtime_library_path_buffer.len() as size_t,
},
);
@@ -160,7 +175,7 @@ impl LibrarySupport {
let result = librarySupportCompile(
self.support,
MlirStringRef {
data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char,
data: mlir_code_buffer.as_ptr() as *const c_char,
length: mlir_code_buffer.len() as size_t,
},
options,
@@ -281,7 +296,7 @@ impl ClientSupport {
Some(path) => {
let cache_path_buffer = path.to_str().unwrap().as_bytes();
let cache = keySetCacheCreate(MlirStringRef {
data: cache_path_buffer.as_ptr() as *const std::os::raw::c_char,
data: cache_path_buffer.as_ptr() as *const c_char,
length: cache_path_buffer.len() as size_t,
});
if keySetCacheIsNull(cache) {
@@ -383,6 +398,174 @@ impl ClientSupport {
}
}
// TODO: implement traits for C Struct that are serializable and reduce code for serialization and maybe refactor other functions.
// destroy and is_null could be implemented for other struct as well.
//
// trait Serializable {
// fn into_buffer_ref(self) -> BufferRef;
// fn from_buffer_ref(buff: BufferRef, params: Option<ClientParameters>) -> Self;
// fn is_null(self) -> bool;
// fn destroy(self);
// }
// fn serialize<T: Serializable>(to_serialize: T) -> Result<Vec<c_char>, CompilerError> {
// unsafe {
// let serialized_ref = to_serialize.into_buffer_ref();
// if bufferRefIsNull(serialized_ref) {
// let error_msg = get_error_msg_from_ctype(serialized_ref);
// bufferRefDestroy(serialized_ref);
// return Err(CompilerError(error_msg));
// }
// let serialized = buffer_ref_to_bytes(serialized_ref);
// Ok(serialized)
// }
// }
// fn unserialize<T: Serializable + CStructErrorMsg>(
// serialized: &Vec<c_char>,
// client_parameters: Option<ClientParameters>,
// ) -> Result<T, CompilerError> {
// unsafe {
// let serialized_ref = bufferRefCreate(
// serialized.as_ptr() as *const c_char,
// serialized.len().try_into().unwrap(),
// );
// let serialized = T::from_buffer_ref(serialized_ref, client_parameters);
// if serialized.is_null() {
// let error_msg = get_error_msg_from_ctype(serialized);
// serialized.destroy();
// return Err(CompilerError(error_msg));
// }
// Ok(serialized)
// }
// }
impl PublicArguments {
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
unsafe {
let serialized_ref = publicArgumentsSerialize(self);
if bufferRefIsNull(serialized_ref) {
let error_msg = get_error_msg_from_ctype(serialized_ref);
bufferRefDestroy(serialized_ref);
return Err(CompilerError(error_msg));
}
let serialized = buffer_ref_to_bytes(serialized_ref);
Ok(serialized)
}
}
pub fn unserialize(
serialized: &Vec<c_char>,
client_parameters: ClientParameters,
) -> Result<PublicArguments, CompilerError> {
unsafe {
let serialized_ref = bufferRefCreate(
serialized.as_ptr() as *const c_char,
serialized.len().try_into().unwrap(),
);
let public_args = publicArgumentsUnserialize(serialized_ref, client_parameters);
if publicArgumentsIsNull(public_args) {
let error_msg = get_error_msg_from_ctype(public_args);
publicArgumentsDestroy(public_args);
return Err(CompilerError(error_msg));
}
Ok(public_args)
}
}
}
impl PublicResult {
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
unsafe {
let serialized_ref = publicResultSerialize(self);
if bufferRefIsNull(serialized_ref) {
let error_msg = get_error_msg_from_ctype(serialized_ref);
bufferRefDestroy(serialized_ref);
return Err(CompilerError(error_msg));
}
let serialized = buffer_ref_to_bytes(serialized_ref);
Ok(serialized)
}
}
pub fn unserialize(
serialized: &Vec<c_char>,
client_parameters: ClientParameters,
) -> Result<PublicResult, CompilerError> {
unsafe {
let serialized_ref = bufferRefCreate(
serialized.as_ptr() as *const c_char,
serialized.len().try_into().unwrap(),
);
let public_result = publicResultUnserialize(serialized_ref, client_parameters);
if publicResultIsNull(public_result) {
let error_msg = get_error_msg_from_ctype(public_result);
publicResultDestroy(public_result);
return Err(CompilerError(error_msg));
}
Ok(public_result)
}
}
}
impl EvaluationKeys {
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
unsafe {
let serialized_ref = evaluationKeysSerialize(self);
if bufferRefIsNull(serialized_ref) {
let error_msg = get_error_msg_from_ctype(serialized_ref);
bufferRefDestroy(serialized_ref);
return Err(CompilerError(error_msg));
}
let serialized = buffer_ref_to_bytes(serialized_ref);
Ok(serialized)
}
}
pub fn unserialize(serialized: &Vec<c_char>) -> Result<EvaluationKeys, CompilerError> {
unsafe {
let serialized_ref = bufferRefCreate(
serialized.as_ptr() as *const c_char,
serialized.len().try_into().unwrap(),
);
let eval_keys = evaluationKeysUnserialize(serialized_ref);
if evaluationKeysIsNull(eval_keys) {
let error_msg = get_error_msg_from_ctype(eval_keys);
evaluationKeysDestroy(eval_keys);
return Err(CompilerError(error_msg));
}
Ok(eval_keys)
}
}
}
impl ClientParameters {
pub fn serialize(self) -> Result<Vec<c_char>, CompilerError> {
unsafe {
let serialized_ref = clientParametersSerialize(self);
if bufferRefIsNull(serialized_ref) {
let error_msg = get_error_msg_from_ctype(serialized_ref);
bufferRefDestroy(serialized_ref);
return Err(CompilerError(error_msg));
}
let serialized = buffer_ref_to_bytes(serialized_ref);
Ok(serialized)
}
}
pub fn unserialize(serialized: &Vec<c_char>) -> Result<ClientParameters, CompilerError> {
unsafe {
let serialized_ref = bufferRefCreate(
serialized.as_ptr() as *const c_char,
serialized.len().try_into().unwrap(),
);
let params = clientParametersUnserialize(serialized_ref);
if clientParametersIsNull(params) {
let error_msg = get_error_msg_from_ctype(params);
clientParametersDestroy(params);
return Err(CompilerError(error_msg));
}
Ok(params)
}
}
}
#[cfg(test)]
mod test {
use std::env;
@@ -531,6 +714,69 @@ mod test {
}
}
#[test]
fn test_compiler_compile_and_exec_with_serialization() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => val + "/lib/libConcretelangRuntime.so",
Err(_e) => "".to_string(),
};
let temp_dir =
TempDir::new("rust_test_compiler_compile_and_exec_with_serialization").unwrap();
let lib_support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
)
.unwrap();
// compile
let result = lib_support.compile(module_to_compile, None).unwrap();
// loading materials from compilation
// - server_lambda: used for execution
// - client_parameters: used for keygen, encryption, and evaluation keys
let server_lambda = lib_support.load_server_lambda(result).unwrap();
let client_params = lib_support.load_client_parameters(result).unwrap();
// serialize client parameters
let serialized_params = client_params.serialize().unwrap();
let client_params = ClientParameters::unserialize(&serialized_params).unwrap();
// create client support
let client_support = ClientSupport::new(client_params, None).unwrap();
let key_set = client_support.keyset(None, None).unwrap();
let eval_keys = keySetGetEvaluationKeys(key_set);
// serialize eval keys
let serialized_eval_keys = eval_keys.serialize().unwrap();
let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap();
// build lambda arguments from scalar and encrypt them
let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)];
let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
// free args
args.map(|arg| lambdaArgumentDestroy(arg));
// serialize args
let serialized_encrypted_args = encrypted_args.serialize().unwrap();
let encrypted_args =
PublicArguments::unserialize(&serialized_encrypted_args, client_params).unwrap();
// execute the compiled function on the encrypted arguments
let encrypted_result = lib_support
.server_lambda_call(server_lambda, encrypted_args, eval_keys)
.unwrap();
// serialize result
let serialized_encrypted_result = encrypted_result.serialize().unwrap();
let encrypted_result =
PublicResult::unserialize(&serialized_encrypted_result, client_params).unwrap();
// decrypt the result of execution
let result_arg = client_support
.decrypt_result(encrypted_result, key_set)
.unwrap();
// get the scalar value from the result lambda argument
let result = lambdaArgumentGetScalar(result_arg);
assert_eq!(result, 6);
}
}
#[test]
fn test_tensor_lambda_argument() {
unsafe {

View File

@@ -17,14 +17,48 @@
auto *cpp = unwrap(c_struct); \
if (cpp != NULL) \
delete cpp; \
char *error = getErrorPtr(c_struct); \
const char *error = getErrorPtr(c_struct); \
if (error != NULL) \
delete[] error;
/// ********** BufferRef CAPI **************************************************
BufferRef bufferRefCreate(const char *buffer, size_t length) {
return BufferRef{buffer, length, NULL};
}
BufferRef bufferRefFromString(std::string str) {
char *buffer = new char[str.size()];
memcpy(buffer, str.c_str(), str.size());
return bufferRefCreate(buffer, str.size());
}
BufferRef bufferRefFromStringError(std::string error) {
char *buffer = new char[error.size()];
memcpy(buffer, error.c_str(), error.size());
return BufferRef{NULL, 0, buffer};
}
void bufferRefDestroy(BufferRef buffer) {
if (buffer.data != NULL)
delete[] buffer.data;
if (buffer.error != NULL)
delete[] buffer.error;
}
/// ********** Utilities *******************************************************
void mlirStringRefDestroy(MlirStringRef str) { delete[] str.data; }
template <typename T> BufferRef serialize(T toSerialize) {
std::ostringstream ostream(std::ios::binary);
auto voidOrError = unwrap(toSerialize)->serialize(ostream);
if (voidOrError.has_error()) {
return bufferRefFromStringError(voidOrError.error().mesg);
}
return bufferRefFromString(ostream.str());
}
/// ********** CompilationOptions CAPI *****************************************
CompilationOptions
@@ -273,10 +307,31 @@ void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) }
/// ********** ServerLamda CAPI ************************************************
void serverLambdaDestroy(ServerLambda server) { C_STRUCT_CLEANER(server) }
void serverLambdaDestroy(ServerLambda server){C_STRUCT_CLEANER(server)}
/// ********** ClientParameters CAPI *******************************************
BufferRef clientParametersSerialize(ClientParameters params) {
llvm::json::Value value(*unwrap(params));
std::string jsonParams;
llvm::raw_string_ostream ostream(jsonParams);
ostream << value;
char *buffer = new char[jsonParams.size() + 1];
strcpy(buffer, jsonParams.c_str());
return bufferRefCreate(buffer, jsonParams.size());
}
ClientParameters clientParametersUnserialize(BufferRef buffer) {
std::string json(buffer.data, buffer.length);
auto paramsOrError =
llvm::json::parse<mlir::concretelang::ClientParameters>(json);
if (!paramsOrError) {
return wrap((mlir::concretelang::ClientParameters *)NULL,
llvm::toString(paramsOrError.takeError()));
}
return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get()));
}
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
/// ********** KeySet CAPI *****************************************************
@@ -318,12 +373,31 @@ KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
return wrap(keySetOrError.value().release());
}
void keySetCacheDestroy(KeySetCache keySetCache) {
C_STRUCT_CLEANER(keySetCache)
}
void keySetCacheDestroy(KeySetCache keySetCache){C_STRUCT_CLEANER(keySetCache)}
/// ********** EvaluationKeys CAPI *********************************************
BufferRef evaluationKeysSerialize(EvaluationKeys keys) {
std::ostringstream ostream(std::ios::binary);
concretelang::clientlib::operator<<(ostream, *unwrap(keys));
if (ostream.fail()) {
return bufferRefFromStringError(
"output stream failure during evaluation keys serialization");
}
return bufferRefFromString(ostream.str());
}
EvaluationKeys evaluationKeysUnserialize(BufferRef buffer) {
std::stringstream istream(std::string(buffer.data, buffer.length));
concretelang::clientlib::EvaluationKeys evaluationKeys;
concretelang::clientlib::operator>>(istream, evaluationKeys);
if (istream.fail()) {
return wrap((concretelang::clientlib::EvaluationKeys *)NULL,
"input stream failure during evaluation keys unserialization");
}
return wrap(new concretelang::clientlib::EvaluationKeys(evaluationKeys));
}
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
C_STRUCT_CLEANER(evaluationKeys);
}
@@ -540,12 +614,27 @@ PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
return wrap(publicArgsOrError.get().release());
}
void lambdaArgumentDestroy(LambdaArgument lambdaArg) {
C_STRUCT_CLEANER(lambdaArg)
}
void lambdaArgumentDestroy(LambdaArgument lambdaArg){
C_STRUCT_CLEANER(lambdaArg)}
/// ********** PublicArguments CAPI ********************************************
BufferRef publicArgumentsSerialize(PublicArguments args) {
return serialize(args);
}
PublicArguments publicArgumentsUnserialize(BufferRef buffer,
ClientParameters params) {
std::stringstream istream(std::string(buffer.data, buffer.length));
auto argsOrError = concretelang::clientlib::PublicArguments::unserialize(
*unwrap(params), istream);
if (!argsOrError) {
return wrap((concretelang::clientlib::PublicArguments *)NULL,
argsOrError.error().mesg);
}
return wrap(argsOrError.value().release());
}
void publicArgumentsDestroy(PublicArguments publicArgs){
C_STRUCT_CLEANER(publicArgs)}
@@ -563,6 +652,22 @@ LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
return wrap(lambdaArgOrError.get().release());
}
BufferRef publicResultSerialize(PublicResult result) {
return serialize(result);
}
PublicResult publicResultUnserialize(BufferRef buffer,
ClientParameters params) {
std::stringstream istream(std::string(buffer.data, buffer.length));
auto resultOrError = concretelang::clientlib::PublicResult::unserialize(
*unwrap(params), istream);
if (!resultOrError) {
return wrap((concretelang::clientlib::PublicResult *)NULL,
resultOrError.error().mesg);
}
return wrap(resultOrError.value().release());
}
void publicResultDestroy(PublicResult publicResult) {
C_STRUCT_CLEANER(publicResult)
}