From 8c6a0859cd54b9b9f11acbfbc724d09479b4a6ce Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 8 Dec 2022 10:00:54 +0100 Subject: [PATCH] refactor(rust): add Rust wrapper for every CStruct we want to wrap CStructs in RustStructs to own them, and free memeory when they are no longer used. Users won't have to deal with the direct binded CAPI, but the new wrappers --- .../concretelang-c/Support/CompilerEngine.h | 18 +- compiler/lib/Bindings/Rust/src/compiler.rs | 1773 ++++++++++------- compiler/lib/CAPI/Support/CompilerEngine.cpp | 19 +- 3 files changed, 1090 insertions(+), 720 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 63eeb1ef1..696244e0b 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -129,16 +129,27 @@ typedef enum CompilationTarget CompilationTarget; /// ********** CompilationOptions CAPI ***************************************** -MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(); +MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate( + MlirStringRef funcName, bool autoParallelize, bool batchConcreteOps, + bool dataflowParallelize, bool emitGPUOps, bool loopParallelize, + bool optimizeConcrete, OptimizerConfig optimizerConfig, + bool verifyDiagnostics); MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault(); +MLIR_CAPI_EXPORTED void compilationOptionsDestroy(CompilationOptions options); + /// ********** OptimizerConfig CAPI ******************************************** -MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreate(); +MLIR_CAPI_EXPORTED OptimizerConfig +optimizerConfigCreate(bool display, double fallback_log_norm_woppbs, + double global_p_error, double p_error, uint64_t security, + bool strategy_v0, bool use_gpu_constraints); MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault(); +MLIR_CAPI_EXPORTED void optimizerConfigDestroy(OptimizerConfig config); + /// ********** CompilerEngine CAPI ********************************************* MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate(); @@ -224,6 +235,9 @@ MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params); MLIR_CAPI_EXPORTED ClientParameters clientParametersUnserialize(BufferRef buffer); +MLIR_CAPI_EXPORTED ClientParameters +clientParametersCopy(ClientParameters params); + MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params); /// ********** KeySet CAPI ***************************************************** diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 477277775..a18d5ff79 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -1,6 +1,6 @@ //! Compiler module -use crate::mlir::ffi::*; +use crate::mlir::ffi; use std::os::raw::c_char; use std::{ffi::CStr, path::Path}; @@ -23,54 +23,934 @@ macro_rules! impl_CStructErrorMsg { } } impl_CStructErrorMsg! {[ - crate::mlir::ffi::LibrarySupport, - CompilationResult, - LibraryCompilationResult, - ServerLambda, - ClientParameters, - PublicArguments, - PublicResult, - KeySet, - KeySetCache, - LambdaArgument, - BufferRef, - EvaluationKeys + ffi::BufferRef, + ffi::CompilationOptions, + ffi::OptimizerConfig, + ffi::CompilerEngine, + ffi::CompilationResult, + ffi::Library, + ffi::LibraryCompilationResult, + ffi::LibrarySupport, + ffi::ServerLambda, + ffi::ClientParameters, + ffi::KeySet, + ffi::KeySetCache, + ffi::EvaluationKeys, + ffi::LambdaArgument, + ffi::PublicArguments, + ffi::PublicResult, + ffi::CompilationFeedback ]} /// Construct a rust error message from a buffer in the C struct. -fn get_error_msg_from_ctype(c_struct: T) -> String { +fn get_error_msg_from_ctype(c_struct: &T) -> String { unsafe { let error_msg_cstr = CStr::from_ptr(c_struct.error_msg()); String::from(error_msg_cstr.to_str().unwrap()) } } -/// Create string from an MlirStringRef and free its memory. -/// -/// # SAFETY -/// -/// This should only be used with string refs returned by the compiler. -unsafe fn mlir_string_ref_to_string(str_ref: MlirStringRef) -> String { - let result = String::from_utf8_lossy(std::slice::from_raw_parts( - str_ref.data as *const u8, - str_ref.length as usize, - )) - .to_string(); - mlirStringRefDestroy(str_ref); - result +/// Wrapper to own MlirStringRef coming from the compiler and destroy them on drop +struct MlirStringRef(ffi::MlirStringRef); + +impl MlirStringRef { + pub fn to_string(&self) -> Result { + unsafe { + if self.0.data.is_null() { + return Err(CompilerError("string ref points to null".to_string())); + } + let result = String::from_utf8_lossy(std::slice::from_raw_parts( + self.0.data as *const u8, + self.0.length as usize, + )) + .to_string(); + Ok(result) + } + } + + /// Create an ffi MlirStringRef for a rust str. + /// + /// The reason behind not returning a wrapper is that it would lead to freeing rust memory + /// using a custom destructor in C. + /// + /// # SAFETY + /// The caller has to make sure the &str outlive the ffi::MlirStringRef + pub unsafe fn from_rust_str(s: &str) -> ffi::MlirStringRef { + ffi::MlirStringRef { + data: s.as_ptr() as *const c_char, + length: s.len() as ffi::size_t, + } + } } -/// Create a vector of bytes from a BufferRef and free its memory. +impl Drop for MlirStringRef { + fn drop(&mut self) { + unsafe { ffi::mlirStringRefDestroy(self.0) } + } +} + +trait CStructWrapper { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: T) -> Self; + // check if the wrapped c-struct is null + fn is_null(&self) -> bool; + // get error message + fn error_msg(&self) -> String; + // drop + fn destroy(&mut self); +} + +/// Wrapper of CStruct. /// -/// # SAFETY -/// -/// This should only be used with string refs returned by the compiler. -unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec { - let result = - std::slice::from_raw_parts(buffer_ref.data as *const c_char, buffer_ref.length as usize) +/// We want to have a Rust wrapper for every CStruct that will take care of owning +/// it, and freeing memory when it's no longer used. +macro_rules! def_CStructWrapper { + ( + $name:ident => { + $ffi_is_null_fn:ident, + $ffi_destroy_fn:ident + $(,)? + } + ) => { + + pub struct $name{ _c: ffi::$name } + + impl CStructWrapper for $name { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: ffi::$name) -> Self { + Self{_c: c_struct} + } + // check if the wrapped C-struct is null + fn is_null(&self) -> bool { + unsafe { + ffi::$ffi_is_null_fn(self._c) + } + } + // get error message + fn error_msg(&self) -> String { + get_error_msg_from_ctype(&self._c) + } + // free memory allocated for the C-struct + fn destroy(&mut self) { + unsafe { + ffi::$ffi_destroy_fn(self._c) + } + } + } + + impl Drop for $name { + fn drop(&mut self) { + self.destroy(); + } + } + }; + + ( + $( + $name:ident => { + $ffi_is_null_fn:ident, + $ffi_destroy_fn:ident + $(,)? + } + ),+ + $(,)? + ) => { + $( + def_CStructWrapper!{ + $name => { + $ffi_is_null_fn, + $ffi_destroy_fn + } + } + )+ + }; +} +def_CStructWrapper! { + BufferRef => { + bufferRefIsNull, + bufferRefDestroy + }, + CompilationOptions => { + compilationOptionsIsNull, + compilationOptionsDestroy, + }, + OptimizerConfig => { + optimizerConfigIsNull, + optimizerConfigDestroy, + }, + CompilerEngine => { + compilerEngineIsNull, + compilerEngineDestroy, + }, + CompilationResult => { + compilationResultIsNull, + compilationResultDestroy, + }, + Library => { + libraryIsNull, + libraryDestroy, + }, + LibraryCompilationResult => { + libraryCompilationResultIsNull, + libraryCompilationResultDestroy, + }, + LibrarySupport => { + librarySupportIsNull, + librarySupportDestroy, + }, + ServerLambda => { + serverLambdaIsNull, + serverLambdaDestroy, + }, + ClientParameters => { + clientParametersIsNull, + clientParametersDestroy, + }, + KeySetCache => { + keySetCacheIsNull, + keySetCacheDestroy, + }, + EvaluationKeys => { + evaluationKeysIsNull, + evaluationKeysDestroy, + }, + LambdaArgument => { + lambdaArgumentIsNull, + lambdaArgumentDestroy, + }, + PublicArguments => { + publicArgumentsIsNull, + publicArgumentsDestroy, + }, + PublicResult => { + publicResultIsNull, + publicResultDestroy, + }, + CompilationFeedback => { + compilationFeedbackIsNull, + compilationFeedbackDestroy, + } +} + +impl BufferRef { + /// Create a reference to a buffer in memory. + /// + /// The pointed memory will not get owned. The caller must make sure the pointer points + /// to a valid memory region of the provided length, and that the pointed memory outlive + /// the buffer reference. + pub fn new(ptr: *const c_char, length: ffi::size_t) -> Result { + unsafe { + let buffer_ref = ffi::bufferRefCreate(ptr, length); + if ffi::bufferRefIsNull(buffer_ref) { + let error_msg = get_error_msg_from_ctype(&buffer_ref); + ffi::bufferRefDestroy(buffer_ref); + return Err(CompilerError(error_msg)); + } + return Ok(buffer_ref); + } + } + + /// Copy the content of the buffer into a new vector of bytes. + /// + /// Returns an empty vector if the buffer reference a null pointer. + pub fn to_bytes(&self) -> Vec { + if self.is_null() { + return Vec::new(); + } + let buffer_ref_c = self._c; + unsafe { + let result = std::slice::from_raw_parts( + buffer_ref_c.data as *const c_char, + buffer_ref_c.length as usize, + ) .to_vec(); - bufferRefDestroy(buffer_ref); - result + result + } + } +} + +impl CompilationOptions { + pub fn new( + func_name: &str, + auto_parallelize: bool, + batch_concrete_ops: bool, + dataflow_parallelize: bool, + emit_gpu_ops: bool, + loop_parallelize: bool, + optimize_concrete: bool, + optimizer_config: &OptimizerConfig, + verify_diagnostics: bool, + ) -> Result { + unsafe { + let options = CompilationOptions::wrap(ffi::compilationOptionsCreate( + MlirStringRef::from_rust_str(func_name), + auto_parallelize, + batch_concrete_ops, + dataflow_parallelize, + emit_gpu_ops, + loop_parallelize, + optimize_concrete, + optimizer_config._c, + verify_diagnostics, + )); + if options.is_null() { + return Err(CompilerError(options.error_msg())); + } + Ok(options) + } + } + + pub fn get_default() -> Result { + unsafe { + let options = CompilationOptions::wrap(ffi::compilationOptionsCreateDefault()); + if options.is_null() { + return Err(CompilerError(options.error_msg())); + } + Ok(options) + } + } +} + +impl OptimizerConfig { + pub fn new( + display: bool, + fallback_log_norm_woppbs: f64, + global_p_error: f64, + p_error: f64, + security: u64, + strategy_v0: bool, + use_gpu_constraints: bool, + ) -> Result { + unsafe { + let config = OptimizerConfig::wrap(ffi::optimizerConfigCreate( + display, + fallback_log_norm_woppbs, + global_p_error, + p_error, + security, + strategy_v0, + use_gpu_constraints, + )); + if config.is_null() { + return Err(CompilerError(config.error_msg())); + } + Ok(config) + } + } + + pub fn get_default() -> Result { + unsafe { + let config = OptimizerConfig::wrap(ffi::optimizerConfigCreateDefault()); + if config.is_null() { + return Err(CompilerError(config.error_msg())); + } + Ok(config) + } + } +} +impl CompilerEngine { + pub fn new(options: Option<&CompilationOptions>) -> Result { + unsafe { + let engine = CompilerEngine::wrap(ffi::compilerEngineCreate()); + if engine.is_null() { + return Err(CompilerError(engine.error_msg())); + } + if let Some(o) = options { + engine.set_options(o) + } + Ok(engine) + } + } + + pub fn set_options(&self, options: &CompilationOptions) { + unsafe { + ffi::compilerEngineCompileSetOptions(self._c, options._c); + } + } + + pub fn compile( + &self, + module: &str, + target: ffi::CompilationTarget, + ) -> Result { + unsafe { + let module_string_ref = MlirStringRef::from_rust_str(module); + let result = CompilationResult::wrap(ffi::compilerEngineCompile( + self._c, + module_string_ref, + target, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } +} +impl CompilationResult { + pub fn get_module_string(&self) -> Result { + unsafe { MlirStringRef(ffi::compilationResultGetModuleString(self._c)).to_string() } + } +} +impl Library { + pub fn new( + output_dir_path: &str, + runtime_library_path: Option<&str>, + clean_up: bool, + ) -> Result { + unsafe { + let lib = Library::wrap(ffi::libraryCreate( + MlirStringRef::from_rust_str(output_dir_path), + MlirStringRef::from_rust_str(runtime_library_path.unwrap_or("")), + clean_up, + )); + if lib.is_null() { + return Err(CompilerError(lib.error_msg())); + } + Ok(lib) + } + } +} + +impl LibraryCompilationResult {} + +/// Support for compiling and executing libraries. +impl LibrarySupport { + /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. + /// + /// The compiled library needs to link to the runtime for proper execution. + pub fn new( + output_dir_path: &str, + runtime_library_path: Option, + ) -> Result { + unsafe { + let runtime_library_path = match runtime_library_path { + Some(val) => val.to_string(), + None => "".to_string(), + }; + let runtime_library_path_buffer = runtime_library_path.as_str(); + let support = LibrarySupport::wrap(ffi::librarySupportCreateDefault( + MlirStringRef::from_rust_str(output_dir_path), + MlirStringRef::from_rust_str(runtime_library_path_buffer), + )); + if support.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + support.error_msg() + ))); + } + Ok(support) + } + } + + /// Compile an MLIR into a library. + pub fn compile( + &self, + mlir_code: &str, + options: Option, + ) -> Result { + unsafe { + let options = options.unwrap_or_else(|| CompilationOptions::get_default().unwrap()); + let result = LibraryCompilationResult::wrap(ffi::librarySupportCompile( + self._c, + MlirStringRef::from_rust_str(mlir_code), + options._c, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } + + /// Load server lambda from a compilation result. + /// + /// This can be used for executing the compiled function. + pub fn load_server_lambda( + &self, + result: &LibraryCompilationResult, + ) -> Result { + unsafe { + let server = + ServerLambda::wrap(ffi::librarySupportLoadServerLambda(self._c, result._c)); + if server.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + server.error_msg() + ))); + } + Ok(server) + } + } + + /// Load client parameters from a compilation result. + /// + /// This can be used for creating keys for the compiled library. + pub fn load_client_parameters( + &self, + result: &LibraryCompilationResult, + ) -> Result { + unsafe { + let params = + ClientParameters::wrap(ffi::librarySupportLoadClientParameters(self._c, result._c)); + if params.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + params.error_msg() + ))); + } + Ok(params) + } + } + + /// Run a compiled circuit. + pub fn server_lambda_call( + &self, + server_lambda: &ServerLambda, + args: &PublicArguments, + eval_keys: &EvaluationKeys, + ) -> Result { + unsafe { + let result = PublicResult::wrap(ffi::librarySupportServerCall( + self._c, + server_lambda._c, + args._c, + eval_keys._c, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } + + /// Get path to the compiled shared library + pub fn get_shared_lib_path(&self) -> String { + unsafe { + MlirStringRef(ffi::librarySupportGetSharedLibPath(self._c)) + .to_string() + .unwrap() + } + } + + /// Get path to the client parameters + pub fn get_client_parameters_path(&self) -> String { + unsafe { + MlirStringRef(ffi::librarySupportGetClientParametersPath(self._c)) + .to_string() + .unwrap() + } + } +} + +impl ServerLambda {} + +impl ClientParameters { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let params = ClientParameters::wrap(ffi::clientParametersUnserialize(serialized_ref)); + if params.is_null() { + return Err(CompilerError(params.error_msg())); + } + Ok(params) + } + } +} + +impl Clone for ClientParameters { + fn clone(&self) -> Self { + unsafe { ClientParameters::wrap(ffi::clientParametersCopy(self._c)) } + } +} + +struct KeySet_ { + _c: ffi::KeySet, +} + +impl CStructWrapper for KeySet_ { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: ffi::KeySet) -> KeySet_ { + KeySet_ { _c: c_struct } + } + // check if the wrapped C-struct is null + fn is_null(&self) -> bool { + unsafe { ffi::keySetIsNull(self._c) } + } + // get error message + fn error_msg(&self) -> String { + get_error_msg_from_ctype(&self._c) + } + // free memory allocated for the C-struct + fn destroy(&mut self) { + unsafe { ffi::keySetDestroy(self._c) } + } +} + +impl Drop for KeySet_ { + fn drop(&mut self) { + self.destroy(); + } +} +pub struct KeySet { + key_set: KeySet_, + client_params: ClientParameters, +} + +impl KeySet { + /// Get a keyset based on the client parameters, and the different seeds. + /// + /// If a cache is set, this operation would first try to load an existing key, + /// otherwise, a new keyset will be generated. + pub fn new( + client_params: &ClientParameters, + seed_msb: Option, + seed_lsb: Option, + key_set_cache: Option<&KeySetCache>, + ) -> Result { + unsafe { + let key_set = match key_set_cache { + Some(cache) => KeySet_::wrap(ffi::keySetCacheLoadOrGenerateKeySet( + cache._c, + client_params._c, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + )), + None => KeySet_::wrap(ffi::keySetGenerate( + client_params._c, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + )), + }; + if key_set.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + key_set.error_msg() + ))); + } + Ok(KeySet { + key_set, + client_params: client_params.clone(), + }) + } + } + + pub fn get_evaluation_keys(&self) -> Result { + unsafe { + let eval_keys = EvaluationKeys::wrap(ffi::keySetGetEvaluationKeys(self.key_set._c)); + if eval_keys.is_null() { + return Err(CompilerError(eval_keys.error_msg())); + } + Ok(eval_keys) + } + } + + /// Encrypt arguments of a compiled circuit. + pub fn encrypt_args(&self, args: &[LambdaArgument]) -> Result { + LambdaArgument::encrypt_args(args, self) + } + + pub fn decrypt_result(&self, result: &PublicResult) -> Result { + result.decrypt(self) + } +} + +impl KeySetCache { + pub fn new(path: &Path) -> Result { + unsafe { + let cache_path_buffer = path.to_str().unwrap(); + let cache = KeySetCache::wrap(ffi::keySetCacheCreate(MlirStringRef::from_rust_str( + cache_path_buffer, + ))); + if cache.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + cache.error_msg() + ))); + } + Ok(cache) + } + } +} + +impl EvaluationKeys { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::evaluationKeysSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let eval_keys = EvaluationKeys::wrap(ffi::evaluationKeysUnserialize(serialized_ref)); + if eval_keys.is_null() { + return Err(CompilerError(eval_keys.error_msg())); + } + Ok(eval_keys) + } + } +} + +impl LambdaArgument { + pub fn encrypt_args( + args: &[LambdaArgument], + key_set: &KeySet, + ) -> Result { + unsafe { + let args: Vec = args.into_iter().map(|a| a._c).collect(); + let public_args = PublicArguments::wrap(ffi::lambdaArgumentEncrypt( + args.as_ptr(), + args.len() as u64, + key_set.client_params._c, + key_set.key_set._c, + )); + if public_args.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + public_args.error_msg() + ))); + } + Ok(public_args) + } + } + + pub fn from_scalar(scalar: u64) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromScalar(scalar)); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn is_scalar(&self) -> bool { + unsafe { ffi::lambdaArgumentIsScalar(self._c) } + } + + pub fn get_scalar(&self) -> Result { + unsafe { + if !self.is_scalar() { + return Err(CompilerError("argument is not a scalar".to_string())); + } + Ok(ffi::lambdaArgumentGetScalar(self._c)) + } + } + + pub fn from_tensor_u8(data: &[u8], dims: &[i64]) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU8( + data.as_ptr(), + dims.as_ptr(), + dims.len().try_into().unwrap(), + )); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn from_tensor_u64(data: &[u64], dims: &[i64]) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU64( + data.as_ptr(), + dims.as_ptr(), + dims.len().try_into().unwrap(), + )); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn is_tensor(&self) -> bool { + unsafe { ffi::lambdaArgumentIsTensor(self._c) } + } + + pub fn get_data_size(&self) -> Result { + unsafe { + if !self.is_tensor() { + return Err(CompilerError("argument is not a tensor".to_string())); + } + Ok(ffi::lambdaArgumentGetTensorDataSize(self._c)) + } + } + + pub fn get_rank(&self) -> Result { + unsafe { + if !self.is_tensor() { + return Err(CompilerError("argument is not a tensor".to_string())); + } + Ok(ffi::lambdaArgumentGetTensorRank(self._c)) + } + } + + pub fn get_dims(&self) -> Result, CompilerError> { + unsafe { + let rank = self.get_rank().unwrap(); + let mut dims = Vec::new(); + dims.resize(rank.try_into().unwrap(), 0); + if !ffi::lambdaArgumentGetTensorDims(self._c, dims.as_mut_ptr()) { + return Err(CompilerError("couldn't get dims".to_string())); + } + Ok(dims) + } + } + + pub fn get_data(&self) -> Result, CompilerError> { + unsafe { + let size = self.get_data_size().unwrap(); + let mut data = Vec::new(); + data.resize(size.try_into().unwrap(), 0); + if !ffi::lambdaArgumentGetTensorData(self._c, data.as_mut_ptr()) { + return Err(CompilerError("couldn't get data".to_string())); + } + Ok(data) + } + } +} + +impl PublicArguments { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::publicArgumentsSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: &ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let public_args = PublicArguments::wrap(ffi::publicArgumentsUnserialize( + serialized_ref, + client_parameters._c, + )); + if public_args.is_null() { + return Err(CompilerError(public_args.error_msg())); + } + Ok(public_args) + } + } +} + +impl PublicResult { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::publicResultSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: &ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let public_result = PublicResult::wrap(ffi::publicResultUnserialize( + serialized_ref, + client_parameters._c, + )); + if public_result.is_null() { + return Err(CompilerError(public_result.error_msg())); + } + Ok(public_result) + } + } + + pub fn decrypt(&self, key_set: &KeySet) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::publicResultDecrypt(self._c, key_set.key_set._c)); + if arg.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + arg.error_msg() + ))); + } + Ok(arg) + } + } +} + +impl CompilationFeedback { + pub fn get_complexity(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetComplexity(self._c) } + } + + pub fn get_p_error(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetPError(self._c) } + } + + pub fn get_global_p_error(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetGlobalPError(self._c) } + } + + pub fn get_total_secret_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalSecretKeysSize(self._c) } + } + + pub fn get_total_bootstrap_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalBootstrapKeysSize(self._c) } + } + + pub fn get_total_keyswitch_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalKeyswitchKeysSize(self._c) } + } + + pub fn get_total_inputs_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalInputsSize(self._c) } + } + + pub fn get_total_outputs_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalOutputsSize(self._c) } + } } /// Parse the MLIR code and returns it. @@ -91,479 +971,9 @@ unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec { /// ``` /// pub fn round_trip(mlir_code: &str) -> Result { - unsafe { - let engine = compilerEngineCreate(); - let mlir_code_buffer = mlir_code.as_bytes(); - let compilation_result = compilerEngineCompile( - engine, - MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const c_char, - length: mlir_code_buffer.len() as size_t, - }, - CompilationTarget_ROUND_TRIP, - ); - if compilationResultIsNull(compilation_result) { - let error_msg = get_error_msg_from_ctype(compilation_result); - compilationResultDestroy(compilation_result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - let module_compiled = compilationResultGetModuleString(compilation_result); - let result_str = mlir_string_ref_to_string(module_compiled); - compilerEngineDestroy(engine); - Ok(result_str) - } -} - -/// Support for compiling and executing libraries. -pub struct LibrarySupport { - support: crate::mlir::ffi::LibrarySupport, -} - -impl Drop for LibrarySupport { - fn drop(&mut self) { - unsafe { - librarySupportDestroy(self.support); - } - } -} - -impl LibrarySupport { - /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. - /// - /// The compiled library needs to link to the runtime for proper execution. - pub fn new( - output_dir_path: &str, - runtime_library_path: &str, - ) -> Result { - unsafe { - let output_dir_path_buffer = output_dir_path.as_bytes(); - let runtime_library_path_buffer = runtime_library_path.as_bytes(); - let support = librarySupportCreateDefault( - MlirStringRef { - data: output_dir_path_buffer.as_ptr() as *const c_char, - length: output_dir_path_buffer.len() as size_t, - }, - MlirStringRef { - data: runtime_library_path_buffer.as_ptr() as *const c_char, - length: runtime_library_path_buffer.len() as size_t, - }, - ); - if librarySupportIsNull(support) { - let error_msg = get_error_msg_from_ctype(support); - librarySupportDestroy(support); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(LibrarySupport { support }) - } - } - - /// Compile an MLIR into a library. - pub fn compile( - &self, - mlir_code: &str, - options: Option, - ) -> Result { - unsafe { - let options = options.unwrap_or_else(|| compilationOptionsCreateDefault()); - let mlir_code_buffer = mlir_code.as_bytes(); - let result = librarySupportCompile( - self.support, - MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const c_char, - length: mlir_code_buffer.len() as size_t, - }, - options, - ); - if libraryCompilationResultIsNull(result) { - let error_msg = get_error_msg_from_ctype(result); - libraryCompilationResultDestroy(result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(result) - } - } - - /// Load server lambda from a compilation result. - /// - /// This can be used for executing the compiled function. - pub fn load_server_lambda( - &self, - result: LibraryCompilationResult, - ) -> Result { - unsafe { - let server = librarySupportLoadServerLambda(self.support, result); - if serverLambdaIsNull(server) { - let error_msg = get_error_msg_from_ctype(server); - serverLambdaDestroy(server); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(server) - } - } - - /// Load client parameters from a compilation result. - /// - /// This can be used for creating keys for the compiled library. - pub fn load_client_parameters( - &self, - result: LibraryCompilationResult, - ) -> Result { - unsafe { - let params = librarySupportLoadClientParameters(self.support, result); - if clientParametersIsNull(params) { - let error_msg = get_error_msg_from_ctype(params); - clientParametersDestroy(params); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(params) - } - } - - /// Run a compiled circuit. - pub fn server_lambda_call( - &self, - server_lambda: ServerLambda, - args: PublicArguments, - eval_keys: EvaluationKeys, - ) -> Result { - unsafe { - let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys); - if publicResultIsNull(result) { - let error_msg = get_error_msg_from_ctype(result); - publicResultDestroy(result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(result) - } - } - - /// Get path to the compiled shared library - pub fn get_shared_lib_path(&self) -> String { - unsafe { mlir_string_ref_to_string(librarySupportGetSharedLibPath(self.support)) } - } - - /// Get path to the client parameters - pub fn get_client_parameters_path(&self) -> String { - unsafe { mlir_string_ref_to_string(librarySupportGetClientParametersPath(self.support)) } - } -} - -/// Support for keygen, encryption, and decryption. -/// -/// Manages cache for keys if provided during creation. -pub struct ClientSupport { - client_params: crate::mlir::ffi::ClientParameters, - key_set_cache: Option, -} - -impl Drop for ClientSupport { - fn drop(&mut self) { - unsafe { - clientParametersDestroy(self.client_params); - match self.key_set_cache { - Some(cache) => keySetCacheDestroy(cache), - None => (), - } - } - } -} - -impl ClientSupport { - pub fn new( - client_params: ClientParameters, - key_set_cache_path: Option<&Path>, - ) -> Result { - unsafe { - let key_set_cache = match key_set_cache_path { - Some(path) => { - let cache_path_buffer = path.to_str().unwrap().as_bytes(); - let cache = keySetCacheCreate(MlirStringRef { - data: cache_path_buffer.as_ptr() as *const c_char, - length: cache_path_buffer.len() as size_t, - }); - if keySetCacheIsNull(cache) { - let error_msg = get_error_msg_from_ctype(cache); - keySetCacheDestroy(cache); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Some(cache) - } - None => None, - }; - Ok(ClientSupport { - client_params, - key_set_cache, - }) - } - } - - /// Fetch a keyset based on the client parameters, and the different seeds. - /// - /// If a cache has already been set, this operation would first try to load an existing key, - /// and generate a new one if no compatible keyset exists. - pub fn keyset( - &self, - seed_msb: Option, - seed_lsb: Option, - ) -> Result { - unsafe { - let key_set = match self.key_set_cache { - Some(cache) => keySetCacheLoadOrGenerateKeySet( - cache, - self.client_params, - seed_msb.unwrap_or(0), - seed_lsb.unwrap_or(0), - ), - None => keySetGenerate( - self.client_params, - seed_msb.unwrap_or(0), - seed_lsb.unwrap_or(0), - ), - }; - if keySetIsNull(key_set) { - let error_msg = get_error_msg_from_ctype(key_set); - keySetDestroy(key_set); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(key_set) - } - } - - /// Encrypt arguments of a compiled circuit. - pub fn encrypt_args( - &self, - args: &[LambdaArgument], - key_set: KeySet, - ) -> Result { - unsafe { - let public_args = lambdaArgumentEncrypt( - args.as_ptr(), - args.len() as u64, - self.client_params, - key_set, - ); - if publicArgumentsIsNull(public_args) { - let error_msg = get_error_msg_from_ctype(public_args); - publicArgumentsDestroy(public_args); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(public_args) - } - } - - pub fn decrypt_result( - &self, - result: PublicResult, - key_set: KeySet, - ) -> Result { - unsafe { - let arg = publicResultDecrypt(result, key_set); - if lambdaArgumentIsNull(arg) { - let error_msg = get_error_msg_from_ctype(arg); - lambdaArgumentDestroy(arg); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(arg) - } - } -} - -// TODO: implement traits for C Struct that are serializable and reduce code for serialization and maybe refactor other functions. -// destroy and is_null could be implemented for other struct as well. -// -// trait Serializable { -// fn into_buffer_ref(self) -> BufferRef; -// fn from_buffer_ref(buff: BufferRef, params: Option) -> Self; -// fn is_null(self) -> bool; -// fn destroy(self); -// } - -// fn serialize(to_serialize: T) -> Result, CompilerError> { -// unsafe { -// let serialized_ref = to_serialize.into_buffer_ref(); -// if bufferRefIsNull(serialized_ref) { -// let error_msg = get_error_msg_from_ctype(serialized_ref); -// bufferRefDestroy(serialized_ref); -// return Err(CompilerError(error_msg)); -// } -// let serialized = buffer_ref_to_bytes(serialized_ref); -// Ok(serialized) -// } -// } - -// fn unserialize( -// serialized: &Vec, -// client_parameters: Option, -// ) -> Result { -// unsafe { -// let serialized_ref = bufferRefCreate( -// serialized.as_ptr() as *const c_char, -// serialized.len().try_into().unwrap(), -// ); -// let serialized = T::from_buffer_ref(serialized_ref, client_parameters); -// if serialized.is_null() { -// let error_msg = get_error_msg_from_ctype(serialized); -// serialized.destroy(); -// return Err(CompilerError(error_msg)); -// } -// Ok(serialized) -// } -// } - -impl PublicArguments { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = publicArgumentsSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize( - serialized: &Vec, - client_parameters: ClientParameters, - ) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let public_args = publicArgumentsUnserialize(serialized_ref, client_parameters); - if publicArgumentsIsNull(public_args) { - let error_msg = get_error_msg_from_ctype(public_args); - publicArgumentsDestroy(public_args); - return Err(CompilerError(error_msg)); - } - Ok(public_args) - } - } -} - -impl PublicResult { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = publicResultSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize( - serialized: &Vec, - client_parameters: ClientParameters, - ) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let public_result = publicResultUnserialize(serialized_ref, client_parameters); - if publicResultIsNull(public_result) { - let error_msg = get_error_msg_from_ctype(public_result); - publicResultDestroy(public_result); - return Err(CompilerError(error_msg)); - } - Ok(public_result) - } - } -} - -impl EvaluationKeys { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = evaluationKeysSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize(serialized: &Vec) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let eval_keys = evaluationKeysUnserialize(serialized_ref); - if evaluationKeysIsNull(eval_keys) { - let error_msg = get_error_msg_from_ctype(eval_keys); - evaluationKeysDestroy(eval_keys); - return Err(CompilerError(error_msg)); - } - Ok(eval_keys) - } - } -} - -impl ClientParameters { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = clientParametersSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize(serialized: &Vec) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let params = clientParametersUnserialize(serialized_ref); - if clientParametersIsNull(params) { - let error_msg = get_error_msg_from_ctype(params); - clientParametersDestroy(params); - return Err(CompilerError(error_msg)); - } - Ok(params) - } - } + let engine = CompilerEngine::new(None).unwrap(); + let compilation_result = engine.compile(mlir_code, ffi::CompilationTarget_ROUND_TRIP)?; + compilation_result.get_module_string() } #[cfg(test)] @@ -573,6 +983,13 @@ mod test { use super::*; + fn get_runtime_lib_path() -> Option { + match env::var("CONCRETE_COMPILER_INSTALL_DIR") { + Ok(val) => Some(val + "/lib/libConcretelangRuntime.so"), + Err(_e) => None, + } + } + #[test] fn test_compiler_round_trip() { let module_to_compile = " @@ -602,262 +1019,192 @@ mod test { #[test] fn test_compiler_compile_lib() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_lib").unwrap(); - let support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) - .unwrap(); - let lib = support.compile(module_to_compile, None).unwrap(); - assert!(!libraryCompilationResultIsNull(lib)); - libraryCompilationResultDestroy(lib); - // the sharedlib should be enough as a sign that the compilation worked - assert!(Path::new(support.get_shared_lib_path().as_str()).exists()); - assert!(Path::new(support.get_client_parameters_path().as_str()).exists()); - } + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + let lib = support.compile(module_to_compile, None).unwrap(); + assert!(!lib.is_null()); + // the sharedlib should be enough as a sign that the compilation worked + assert!(Path::new(support.get_shared_lib_path().as_str()).exists()); + assert!(Path::new(support.get_client_parameters_path().as_str()).exists()); } /// We want to make sure setting a pointer to null in rust passes the nullptr check in C/Cpp #[test] fn test_compiler_null_ptr_compatibility() { unsafe { - let lib = Library { + let lib = ffi::Library { ptr: std::ptr::null_mut(), error: std::ptr::null_mut(), }; - assert!(libraryIsNull(lib)); + assert!(ffi::libraryIsNull(lib)); } } #[test] fn test_compiler_load_server_lambda_and_client_parameters() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_load_server_lambda").unwrap(); - let support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) - .unwrap(); - let result = support.compile(module_to_compile, None).unwrap(); - let server = support.load_server_lambda(result).unwrap(); - assert!(!serverLambdaIsNull(server)); - serverLambdaDestroy(server); - let client_params = support.load_client_parameters(result).unwrap(); - assert!(!clientParametersIsNull(client_params)); - libraryCompilationResultDestroy(result); - } + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + let result = support.compile(module_to_compile, None).unwrap(); + let server = support.load_server_lambda(&result).unwrap(); + assert!(!server.is_null()); + let client_params = support.load_client_parameters(&result).unwrap(); + assert!(!client_params.is_null()); } #[test] fn test_compiler_compile_and_exec_scalar_args() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_scalar_args").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_scalar(4).unwrap(), + LambdaArgument::from_scalar(2).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // build lambda arguments from scalar and encrypt them - let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // free args - args.map(|arg| lambdaArgumentDestroy(arg)); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // get the scalar value from the result lambda argument - let result = lambdaArgumentGetScalar(result_arg); - assert_eq!(result, 6); - } + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // get the scalar value from the result lambda argument + let result = result_arg.get_scalar().unwrap(); + assert_eq!(result, 6); } #[test] fn test_compiler_compile_and_exec_with_serialization() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = - TempDir::new("rust_test_compiler_compile_and_exec_with_serialization").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + // serialize client parameters + let serialized_params = client_params.serialize().unwrap(); + let client_params = ClientParameters::unserialize(&serialized_params).unwrap(); + // generate keys + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // serialize eval keys + let serialized_eval_keys = eval_keys.serialize().unwrap(); + let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_scalar(4).unwrap(), + LambdaArgument::from_scalar(2).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // serialize args + let serialized_encrypted_args = encrypted_args.serialize().unwrap(); + let encrypted_args = + PublicArguments::unserialize(&serialized_encrypted_args, &client_params).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - // serialize client parameters - let serialized_params = client_params.serialize().unwrap(); - let client_params = ClientParameters::unserialize(&serialized_params).unwrap(); - // create client support - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // serialize eval keys - let serialized_eval_keys = eval_keys.serialize().unwrap(); - let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap(); - // build lambda arguments from scalar and encrypt them - let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // free args - args.map(|arg| lambdaArgumentDestroy(arg)); - // serialize args - let serialized_encrypted_args = encrypted_args.serialize().unwrap(); - let encrypted_args = - PublicArguments::unserialize(&serialized_encrypted_args, client_params).unwrap(); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // serialize result - let serialized_encrypted_result = encrypted_result.serialize().unwrap(); - let encrypted_result = - PublicResult::unserialize(&serialized_encrypted_result, client_params).unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // get the scalar value from the result lambda argument - let result = lambdaArgumentGetScalar(result_arg); - assert_eq!(result, 6); - } + // serialize result + let serialized_encrypted_result = encrypted_result.serialize().unwrap(); + let encrypted_result = + PublicResult::unserialize(&serialized_encrypted_result, &client_params).unwrap(); + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // get the scalar value from the result lambda argument + let result = result_arg.get_scalar().unwrap(); + assert_eq!(result, 6); } #[test] fn test_tensor_lambda_argument() { - unsafe { - let tensor_data = [1, 2, 3, 73u64]; - let tensor_dims = [2, 2i64]; - let tensor_arg = - lambdaArgumentFromTensorU64(tensor_data.as_ptr(), tensor_dims.as_ptr(), 2); - assert!(!lambdaArgumentIsNull(tensor_arg)); - assert!(!lambdaArgumentIsScalar(tensor_arg)); - assert!(lambdaArgumentIsTensor(tensor_arg)); - assert_eq!(lambdaArgumentGetTensorRank(tensor_arg), 2); - assert_eq!(lambdaArgumentGetTensorDataSize(tensor_arg), 4); - let mut dims: [i64; 2] = [0, 0]; - assert_eq!( - lambdaArgumentGetTensorDims(tensor_arg, dims.as_mut_ptr()), - true - ); - assert_eq!(dims, tensor_dims); - - let mut data: [u64; 4] = [0; 4]; - assert_eq!( - lambdaArgumentGetTensorData(tensor_arg, data.as_mut_ptr()), - true - ); - assert_eq!(data, tensor_data); - lambdaArgumentDestroy(tensor_arg); - } + let tensor_data = [1, 2, 3, 73u64]; + let tensor_dims = [2, 2i64]; + let tensor_arg = LambdaArgument::from_tensor_u64(&tensor_data, &tensor_dims).unwrap(); + assert!(!tensor_arg.is_null()); + assert!(!tensor_arg.is_scalar()); + assert!(tensor_arg.is_tensor()); + assert_eq!(tensor_arg.get_rank().unwrap(), 2); + assert_eq!(tensor_arg.get_data_size().unwrap(), 4); + assert_eq!(tensor_arg.get_dims().unwrap(), tensor_dims); + assert_eq!(tensor_arg.get_data().unwrap(), tensor_data); } #[test] fn test_compiler_compile_and_exec_tensor_args() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> { %0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> return %0 : tensor<2x3x!FHE.eint<5>> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_tensor_args").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_tensor_u8(&[1, 2, 3, 4, 5, 6], &[2, 3]).unwrap(), + LambdaArgument::from_tensor_u8(&[1, 4, 7, 4, 2, 9], &[2, 3]).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // build lambda arguments from scalar and encrypt them - let args = [ - lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_ptr(), [2, 3].as_ptr(), 2), - lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_ptr(), [2, 3].as_ptr(), 2), - ]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // check the tensor dims value from the result lambda argument - assert_eq!(lambdaArgumentGetTensorRank(result_arg), 2); - assert_eq!(lambdaArgumentGetTensorDataSize(result_arg), 6); - let mut dims = [0, 0]; - assert!(lambdaArgumentGetTensorDims(result_arg, dims.as_mut_ptr())); - assert_eq!(dims, [2, 3]); - // check the tensor data from the result lambda argument - let mut data = [0; 6]; - assert!(lambdaArgumentGetTensorData(result_arg, data.as_mut_ptr())); - assert_eq!(data, [2, 6, 10, 8, 7, 15]); - } + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // check the tensor dims value from the result lambda argument + assert_eq!(result_arg.get_rank().unwrap(), 2); + assert_eq!(result_arg.get_data_size().unwrap(), 6); + assert_eq!(result_arg.get_dims().unwrap(), [2, 3]); + // check the tensor data from the result lambda argument + assert_eq!(result_arg.get_data().unwrap(), [2, 6, 10, 8, 7, 15]); } } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 2adbe6a0b..c90840328 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -84,13 +84,16 @@ CompilationOptions compilationOptionsCreateDefault() { return wrap(new mlir::concretelang::CompilationOptions("main")); } +void compilationOptionsDestroy(CompilationOptions options){ + C_STRUCT_CLEANER(options)} + /// ********** OptimizerConfig CAPI ******************************************** -OptimizerConfig optimizerConfigCreate(bool display, - double fallback_log_norm_woppbs, - double global_p_error, double p_error, - uint64_t security, bool strategy_v0, - bool use_gpu_constraints) { +OptimizerConfig + optimizerConfigCreate(bool display, double fallback_log_norm_woppbs, + double global_p_error, double p_error, + uint64_t security, bool strategy_v0, + bool use_gpu_constraints) { auto config = new mlir::concretelang::optimizer::Config(); config->display = display; config->fallback_log_norm_woppbs = fallback_log_norm_woppbs; @@ -106,6 +109,8 @@ OptimizerConfig optimizerConfigCreateDefault() { return wrap(new mlir::concretelang::optimizer::Config()); } +void optimizerConfigDestroy(OptimizerConfig config){C_STRUCT_CLEANER(config)} + /// ********** CompilerEngine CAPI ********************************************* CompilerEngine compilerEngineCreate() { @@ -330,6 +335,10 @@ ClientParameters clientParametersUnserialize(BufferRef buffer) { return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get())); } +ClientParameters clientParametersCopy(ClientParameters params) { + return wrap(new mlir::concretelang::ClientParameters(*unwrap(params))); +} + void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)} /// ********** KeySet CAPI *****************************************************