diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 63eeb1ef1..696244e0b 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -129,16 +129,27 @@ typedef enum CompilationTarget CompilationTarget; /// ********** CompilationOptions CAPI ***************************************** -MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(); +MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate( + MlirStringRef funcName, bool autoParallelize, bool batchConcreteOps, + bool dataflowParallelize, bool emitGPUOps, bool loopParallelize, + bool optimizeConcrete, OptimizerConfig optimizerConfig, + bool verifyDiagnostics); MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault(); +MLIR_CAPI_EXPORTED void compilationOptionsDestroy(CompilationOptions options); + /// ********** OptimizerConfig CAPI ******************************************** -MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreate(); +MLIR_CAPI_EXPORTED OptimizerConfig +optimizerConfigCreate(bool display, double fallback_log_norm_woppbs, + double global_p_error, double p_error, uint64_t security, + bool strategy_v0, bool use_gpu_constraints); MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault(); +MLIR_CAPI_EXPORTED void optimizerConfigDestroy(OptimizerConfig config); + /// ********** CompilerEngine CAPI ********************************************* MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate(); @@ -224,6 +235,9 @@ MLIR_CAPI_EXPORTED BufferRef clientParametersSerialize(ClientParameters params); MLIR_CAPI_EXPORTED ClientParameters clientParametersUnserialize(BufferRef buffer); +MLIR_CAPI_EXPORTED ClientParameters +clientParametersCopy(ClientParameters params); + MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params); /// ********** KeySet CAPI ***************************************************** diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 477277775..a18d5ff79 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -1,6 +1,6 @@ //! Compiler module -use crate::mlir::ffi::*; +use crate::mlir::ffi; use std::os::raw::c_char; use std::{ffi::CStr, path::Path}; @@ -23,54 +23,934 @@ macro_rules! impl_CStructErrorMsg { } } impl_CStructErrorMsg! {[ - crate::mlir::ffi::LibrarySupport, - CompilationResult, - LibraryCompilationResult, - ServerLambda, - ClientParameters, - PublicArguments, - PublicResult, - KeySet, - KeySetCache, - LambdaArgument, - BufferRef, - EvaluationKeys + ffi::BufferRef, + ffi::CompilationOptions, + ffi::OptimizerConfig, + ffi::CompilerEngine, + ffi::CompilationResult, + ffi::Library, + ffi::LibraryCompilationResult, + ffi::LibrarySupport, + ffi::ServerLambda, + ffi::ClientParameters, + ffi::KeySet, + ffi::KeySetCache, + ffi::EvaluationKeys, + ffi::LambdaArgument, + ffi::PublicArguments, + ffi::PublicResult, + ffi::CompilationFeedback ]} /// Construct a rust error message from a buffer in the C struct. -fn get_error_msg_from_ctype(c_struct: T) -> String { +fn get_error_msg_from_ctype(c_struct: &T) -> String { unsafe { let error_msg_cstr = CStr::from_ptr(c_struct.error_msg()); String::from(error_msg_cstr.to_str().unwrap()) } } -/// Create string from an MlirStringRef and free its memory. -/// -/// # SAFETY -/// -/// This should only be used with string refs returned by the compiler. -unsafe fn mlir_string_ref_to_string(str_ref: MlirStringRef) -> String { - let result = String::from_utf8_lossy(std::slice::from_raw_parts( - str_ref.data as *const u8, - str_ref.length as usize, - )) - .to_string(); - mlirStringRefDestroy(str_ref); - result +/// Wrapper to own MlirStringRef coming from the compiler and destroy them on drop +struct MlirStringRef(ffi::MlirStringRef); + +impl MlirStringRef { + pub fn to_string(&self) -> Result { + unsafe { + if self.0.data.is_null() { + return Err(CompilerError("string ref points to null".to_string())); + } + let result = String::from_utf8_lossy(std::slice::from_raw_parts( + self.0.data as *const u8, + self.0.length as usize, + )) + .to_string(); + Ok(result) + } + } + + /// Create an ffi MlirStringRef for a rust str. + /// + /// The reason behind not returning a wrapper is that it would lead to freeing rust memory + /// using a custom destructor in C. + /// + /// # SAFETY + /// The caller has to make sure the &str outlive the ffi::MlirStringRef + pub unsafe fn from_rust_str(s: &str) -> ffi::MlirStringRef { + ffi::MlirStringRef { + data: s.as_ptr() as *const c_char, + length: s.len() as ffi::size_t, + } + } } -/// Create a vector of bytes from a BufferRef and free its memory. +impl Drop for MlirStringRef { + fn drop(&mut self) { + unsafe { ffi::mlirStringRefDestroy(self.0) } + } +} + +trait CStructWrapper { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: T) -> Self; + // check if the wrapped c-struct is null + fn is_null(&self) -> bool; + // get error message + fn error_msg(&self) -> String; + // drop + fn destroy(&mut self); +} + +/// Wrapper of CStruct. /// -/// # SAFETY -/// -/// This should only be used with string refs returned by the compiler. -unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec { - let result = - std::slice::from_raw_parts(buffer_ref.data as *const c_char, buffer_ref.length as usize) +/// We want to have a Rust wrapper for every CStruct that will take care of owning +/// it, and freeing memory when it's no longer used. +macro_rules! def_CStructWrapper { + ( + $name:ident => { + $ffi_is_null_fn:ident, + $ffi_destroy_fn:ident + $(,)? + } + ) => { + + pub struct $name{ _c: ffi::$name } + + impl CStructWrapper for $name { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: ffi::$name) -> Self { + Self{_c: c_struct} + } + // check if the wrapped C-struct is null + fn is_null(&self) -> bool { + unsafe { + ffi::$ffi_is_null_fn(self._c) + } + } + // get error message + fn error_msg(&self) -> String { + get_error_msg_from_ctype(&self._c) + } + // free memory allocated for the C-struct + fn destroy(&mut self) { + unsafe { + ffi::$ffi_destroy_fn(self._c) + } + } + } + + impl Drop for $name { + fn drop(&mut self) { + self.destroy(); + } + } + }; + + ( + $( + $name:ident => { + $ffi_is_null_fn:ident, + $ffi_destroy_fn:ident + $(,)? + } + ),+ + $(,)? + ) => { + $( + def_CStructWrapper!{ + $name => { + $ffi_is_null_fn, + $ffi_destroy_fn + } + } + )+ + }; +} +def_CStructWrapper! { + BufferRef => { + bufferRefIsNull, + bufferRefDestroy + }, + CompilationOptions => { + compilationOptionsIsNull, + compilationOptionsDestroy, + }, + OptimizerConfig => { + optimizerConfigIsNull, + optimizerConfigDestroy, + }, + CompilerEngine => { + compilerEngineIsNull, + compilerEngineDestroy, + }, + CompilationResult => { + compilationResultIsNull, + compilationResultDestroy, + }, + Library => { + libraryIsNull, + libraryDestroy, + }, + LibraryCompilationResult => { + libraryCompilationResultIsNull, + libraryCompilationResultDestroy, + }, + LibrarySupport => { + librarySupportIsNull, + librarySupportDestroy, + }, + ServerLambda => { + serverLambdaIsNull, + serverLambdaDestroy, + }, + ClientParameters => { + clientParametersIsNull, + clientParametersDestroy, + }, + KeySetCache => { + keySetCacheIsNull, + keySetCacheDestroy, + }, + EvaluationKeys => { + evaluationKeysIsNull, + evaluationKeysDestroy, + }, + LambdaArgument => { + lambdaArgumentIsNull, + lambdaArgumentDestroy, + }, + PublicArguments => { + publicArgumentsIsNull, + publicArgumentsDestroy, + }, + PublicResult => { + publicResultIsNull, + publicResultDestroy, + }, + CompilationFeedback => { + compilationFeedbackIsNull, + compilationFeedbackDestroy, + } +} + +impl BufferRef { + /// Create a reference to a buffer in memory. + /// + /// The pointed memory will not get owned. The caller must make sure the pointer points + /// to a valid memory region of the provided length, and that the pointed memory outlive + /// the buffer reference. + pub fn new(ptr: *const c_char, length: ffi::size_t) -> Result { + unsafe { + let buffer_ref = ffi::bufferRefCreate(ptr, length); + if ffi::bufferRefIsNull(buffer_ref) { + let error_msg = get_error_msg_from_ctype(&buffer_ref); + ffi::bufferRefDestroy(buffer_ref); + return Err(CompilerError(error_msg)); + } + return Ok(buffer_ref); + } + } + + /// Copy the content of the buffer into a new vector of bytes. + /// + /// Returns an empty vector if the buffer reference a null pointer. + pub fn to_bytes(&self) -> Vec { + if self.is_null() { + return Vec::new(); + } + let buffer_ref_c = self._c; + unsafe { + let result = std::slice::from_raw_parts( + buffer_ref_c.data as *const c_char, + buffer_ref_c.length as usize, + ) .to_vec(); - bufferRefDestroy(buffer_ref); - result + result + } + } +} + +impl CompilationOptions { + pub fn new( + func_name: &str, + auto_parallelize: bool, + batch_concrete_ops: bool, + dataflow_parallelize: bool, + emit_gpu_ops: bool, + loop_parallelize: bool, + optimize_concrete: bool, + optimizer_config: &OptimizerConfig, + verify_diagnostics: bool, + ) -> Result { + unsafe { + let options = CompilationOptions::wrap(ffi::compilationOptionsCreate( + MlirStringRef::from_rust_str(func_name), + auto_parallelize, + batch_concrete_ops, + dataflow_parallelize, + emit_gpu_ops, + loop_parallelize, + optimize_concrete, + optimizer_config._c, + verify_diagnostics, + )); + if options.is_null() { + return Err(CompilerError(options.error_msg())); + } + Ok(options) + } + } + + pub fn get_default() -> Result { + unsafe { + let options = CompilationOptions::wrap(ffi::compilationOptionsCreateDefault()); + if options.is_null() { + return Err(CompilerError(options.error_msg())); + } + Ok(options) + } + } +} + +impl OptimizerConfig { + pub fn new( + display: bool, + fallback_log_norm_woppbs: f64, + global_p_error: f64, + p_error: f64, + security: u64, + strategy_v0: bool, + use_gpu_constraints: bool, + ) -> Result { + unsafe { + let config = OptimizerConfig::wrap(ffi::optimizerConfigCreate( + display, + fallback_log_norm_woppbs, + global_p_error, + p_error, + security, + strategy_v0, + use_gpu_constraints, + )); + if config.is_null() { + return Err(CompilerError(config.error_msg())); + } + Ok(config) + } + } + + pub fn get_default() -> Result { + unsafe { + let config = OptimizerConfig::wrap(ffi::optimizerConfigCreateDefault()); + if config.is_null() { + return Err(CompilerError(config.error_msg())); + } + Ok(config) + } + } +} +impl CompilerEngine { + pub fn new(options: Option<&CompilationOptions>) -> Result { + unsafe { + let engine = CompilerEngine::wrap(ffi::compilerEngineCreate()); + if engine.is_null() { + return Err(CompilerError(engine.error_msg())); + } + if let Some(o) = options { + engine.set_options(o) + } + Ok(engine) + } + } + + pub fn set_options(&self, options: &CompilationOptions) { + unsafe { + ffi::compilerEngineCompileSetOptions(self._c, options._c); + } + } + + pub fn compile( + &self, + module: &str, + target: ffi::CompilationTarget, + ) -> Result { + unsafe { + let module_string_ref = MlirStringRef::from_rust_str(module); + let result = CompilationResult::wrap(ffi::compilerEngineCompile( + self._c, + module_string_ref, + target, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } +} +impl CompilationResult { + pub fn get_module_string(&self) -> Result { + unsafe { MlirStringRef(ffi::compilationResultGetModuleString(self._c)).to_string() } + } +} +impl Library { + pub fn new( + output_dir_path: &str, + runtime_library_path: Option<&str>, + clean_up: bool, + ) -> Result { + unsafe { + let lib = Library::wrap(ffi::libraryCreate( + MlirStringRef::from_rust_str(output_dir_path), + MlirStringRef::from_rust_str(runtime_library_path.unwrap_or("")), + clean_up, + )); + if lib.is_null() { + return Err(CompilerError(lib.error_msg())); + } + Ok(lib) + } + } +} + +impl LibraryCompilationResult {} + +/// Support for compiling and executing libraries. +impl LibrarySupport { + /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. + /// + /// The compiled library needs to link to the runtime for proper execution. + pub fn new( + output_dir_path: &str, + runtime_library_path: Option, + ) -> Result { + unsafe { + let runtime_library_path = match runtime_library_path { + Some(val) => val.to_string(), + None => "".to_string(), + }; + let runtime_library_path_buffer = runtime_library_path.as_str(); + let support = LibrarySupport::wrap(ffi::librarySupportCreateDefault( + MlirStringRef::from_rust_str(output_dir_path), + MlirStringRef::from_rust_str(runtime_library_path_buffer), + )); + if support.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + support.error_msg() + ))); + } + Ok(support) + } + } + + /// Compile an MLIR into a library. + pub fn compile( + &self, + mlir_code: &str, + options: Option, + ) -> Result { + unsafe { + let options = options.unwrap_or_else(|| CompilationOptions::get_default().unwrap()); + let result = LibraryCompilationResult::wrap(ffi::librarySupportCompile( + self._c, + MlirStringRef::from_rust_str(mlir_code), + options._c, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } + + /// Load server lambda from a compilation result. + /// + /// This can be used for executing the compiled function. + pub fn load_server_lambda( + &self, + result: &LibraryCompilationResult, + ) -> Result { + unsafe { + let server = + ServerLambda::wrap(ffi::librarySupportLoadServerLambda(self._c, result._c)); + if server.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + server.error_msg() + ))); + } + Ok(server) + } + } + + /// Load client parameters from a compilation result. + /// + /// This can be used for creating keys for the compiled library. + pub fn load_client_parameters( + &self, + result: &LibraryCompilationResult, + ) -> Result { + unsafe { + let params = + ClientParameters::wrap(ffi::librarySupportLoadClientParameters(self._c, result._c)); + if params.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + params.error_msg() + ))); + } + Ok(params) + } + } + + /// Run a compiled circuit. + pub fn server_lambda_call( + &self, + server_lambda: &ServerLambda, + args: &PublicArguments, + eval_keys: &EvaluationKeys, + ) -> Result { + unsafe { + let result = PublicResult::wrap(ffi::librarySupportServerCall( + self._c, + server_lambda._c, + args._c, + eval_keys._c, + )); + if result.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + result.error_msg() + ))); + } + Ok(result) + } + } + + /// Get path to the compiled shared library + pub fn get_shared_lib_path(&self) -> String { + unsafe { + MlirStringRef(ffi::librarySupportGetSharedLibPath(self._c)) + .to_string() + .unwrap() + } + } + + /// Get path to the client parameters + pub fn get_client_parameters_path(&self) -> String { + unsafe { + MlirStringRef(ffi::librarySupportGetClientParametersPath(self._c)) + .to_string() + .unwrap() + } + } +} + +impl ServerLambda {} + +impl ClientParameters { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::clientParametersSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let params = ClientParameters::wrap(ffi::clientParametersUnserialize(serialized_ref)); + if params.is_null() { + return Err(CompilerError(params.error_msg())); + } + Ok(params) + } + } +} + +impl Clone for ClientParameters { + fn clone(&self) -> Self { + unsafe { ClientParameters::wrap(ffi::clientParametersCopy(self._c)) } + } +} + +struct KeySet_ { + _c: ffi::KeySet, +} + +impl CStructWrapper for KeySet_ { + // wrap a c-struct inside a rust-struct + fn wrap(c_struct: ffi::KeySet) -> KeySet_ { + KeySet_ { _c: c_struct } + } + // check if the wrapped C-struct is null + fn is_null(&self) -> bool { + unsafe { ffi::keySetIsNull(self._c) } + } + // get error message + fn error_msg(&self) -> String { + get_error_msg_from_ctype(&self._c) + } + // free memory allocated for the C-struct + fn destroy(&mut self) { + unsafe { ffi::keySetDestroy(self._c) } + } +} + +impl Drop for KeySet_ { + fn drop(&mut self) { + self.destroy(); + } +} +pub struct KeySet { + key_set: KeySet_, + client_params: ClientParameters, +} + +impl KeySet { + /// Get a keyset based on the client parameters, and the different seeds. + /// + /// If a cache is set, this operation would first try to load an existing key, + /// otherwise, a new keyset will be generated. + pub fn new( + client_params: &ClientParameters, + seed_msb: Option, + seed_lsb: Option, + key_set_cache: Option<&KeySetCache>, + ) -> Result { + unsafe { + let key_set = match key_set_cache { + Some(cache) => KeySet_::wrap(ffi::keySetCacheLoadOrGenerateKeySet( + cache._c, + client_params._c, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + )), + None => KeySet_::wrap(ffi::keySetGenerate( + client_params._c, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + )), + }; + if key_set.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + key_set.error_msg() + ))); + } + Ok(KeySet { + key_set, + client_params: client_params.clone(), + }) + } + } + + pub fn get_evaluation_keys(&self) -> Result { + unsafe { + let eval_keys = EvaluationKeys::wrap(ffi::keySetGetEvaluationKeys(self.key_set._c)); + if eval_keys.is_null() { + return Err(CompilerError(eval_keys.error_msg())); + } + Ok(eval_keys) + } + } + + /// Encrypt arguments of a compiled circuit. + pub fn encrypt_args(&self, args: &[LambdaArgument]) -> Result { + LambdaArgument::encrypt_args(args, self) + } + + pub fn decrypt_result(&self, result: &PublicResult) -> Result { + result.decrypt(self) + } +} + +impl KeySetCache { + pub fn new(path: &Path) -> Result { + unsafe { + let cache_path_buffer = path.to_str().unwrap(); + let cache = KeySetCache::wrap(ffi::keySetCacheCreate(MlirStringRef::from_rust_str( + cache_path_buffer, + ))); + if cache.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + cache.error_msg() + ))); + } + Ok(cache) + } + } +} + +impl EvaluationKeys { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::evaluationKeysSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize(serialized: &Vec) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let eval_keys = EvaluationKeys::wrap(ffi::evaluationKeysUnserialize(serialized_ref)); + if eval_keys.is_null() { + return Err(CompilerError(eval_keys.error_msg())); + } + Ok(eval_keys) + } + } +} + +impl LambdaArgument { + pub fn encrypt_args( + args: &[LambdaArgument], + key_set: &KeySet, + ) -> Result { + unsafe { + let args: Vec = args.into_iter().map(|a| a._c).collect(); + let public_args = PublicArguments::wrap(ffi::lambdaArgumentEncrypt( + args.as_ptr(), + args.len() as u64, + key_set.client_params._c, + key_set.key_set._c, + )); + if public_args.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + public_args.error_msg() + ))); + } + Ok(public_args) + } + } + + pub fn from_scalar(scalar: u64) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromScalar(scalar)); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn is_scalar(&self) -> bool { + unsafe { ffi::lambdaArgumentIsScalar(self._c) } + } + + pub fn get_scalar(&self) -> Result { + unsafe { + if !self.is_scalar() { + return Err(CompilerError("argument is not a scalar".to_string())); + } + Ok(ffi::lambdaArgumentGetScalar(self._c)) + } + } + + pub fn from_tensor_u8(data: &[u8], dims: &[i64]) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU8( + data.as_ptr(), + dims.as_ptr(), + dims.len().try_into().unwrap(), + )); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn from_tensor_u64(data: &[u64], dims: &[i64]) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::lambdaArgumentFromTensorU64( + data.as_ptr(), + dims.as_ptr(), + dims.len().try_into().unwrap(), + )); + if arg.is_null() { + return Err(CompilerError(arg.error_msg())); + } + Ok(arg) + } + } + + pub fn is_tensor(&self) -> bool { + unsafe { ffi::lambdaArgumentIsTensor(self._c) } + } + + pub fn get_data_size(&self) -> Result { + unsafe { + if !self.is_tensor() { + return Err(CompilerError("argument is not a tensor".to_string())); + } + Ok(ffi::lambdaArgumentGetTensorDataSize(self._c)) + } + } + + pub fn get_rank(&self) -> Result { + unsafe { + if !self.is_tensor() { + return Err(CompilerError("argument is not a tensor".to_string())); + } + Ok(ffi::lambdaArgumentGetTensorRank(self._c)) + } + } + + pub fn get_dims(&self) -> Result, CompilerError> { + unsafe { + let rank = self.get_rank().unwrap(); + let mut dims = Vec::new(); + dims.resize(rank.try_into().unwrap(), 0); + if !ffi::lambdaArgumentGetTensorDims(self._c, dims.as_mut_ptr()) { + return Err(CompilerError("couldn't get dims".to_string())); + } + Ok(dims) + } + } + + pub fn get_data(&self) -> Result, CompilerError> { + unsafe { + let size = self.get_data_size().unwrap(); + let mut data = Vec::new(); + data.resize(size.try_into().unwrap(), 0); + if !ffi::lambdaArgumentGetTensorData(self._c, data.as_mut_ptr()) { + return Err(CompilerError("couldn't get data".to_string())); + } + Ok(data) + } + } +} + +impl PublicArguments { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::publicArgumentsSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: &ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let public_args = PublicArguments::wrap(ffi::publicArgumentsUnserialize( + serialized_ref, + client_parameters._c, + )); + if public_args.is_null() { + return Err(CompilerError(public_args.error_msg())); + } + Ok(public_args) + } + } +} + +impl PublicResult { + pub fn serialize(self) -> Result, CompilerError> { + unsafe { + let serialized_ref = BufferRef::wrap(ffi::publicResultSerialize(self._c)); + if serialized_ref.is_null() { + return Err(CompilerError(serialized_ref.error_msg())); + } + Ok(serialized_ref.to_bytes()) + } + } + pub fn unserialize( + serialized: &Vec, + client_parameters: &ClientParameters, + ) -> Result { + unsafe { + let serialized_ref = BufferRef::new( + serialized.as_ptr() as *const c_char, + serialized.len().try_into().unwrap(), + ) + .unwrap(); + let public_result = PublicResult::wrap(ffi::publicResultUnserialize( + serialized_ref, + client_parameters._c, + )); + if public_result.is_null() { + return Err(CompilerError(public_result.error_msg())); + } + Ok(public_result) + } + } + + pub fn decrypt(&self, key_set: &KeySet) -> Result { + unsafe { + let arg = LambdaArgument::wrap(ffi::publicResultDecrypt(self._c, key_set.key_set._c)); + if arg.is_null() { + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + arg.error_msg() + ))); + } + Ok(arg) + } + } +} + +impl CompilationFeedback { + pub fn get_complexity(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetComplexity(self._c) } + } + + pub fn get_p_error(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetPError(self._c) } + } + + pub fn get_global_p_error(&self) -> f64 { + unsafe { ffi::compilationFeedbackGetGlobalPError(self._c) } + } + + pub fn get_total_secret_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalSecretKeysSize(self._c) } + } + + pub fn get_total_bootstrap_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalBootstrapKeysSize(self._c) } + } + + pub fn get_total_keyswitch_keys_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalKeyswitchKeysSize(self._c) } + } + + pub fn get_total_inputs_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalInputsSize(self._c) } + } + + pub fn get_total_outputs_size(&self) -> u64 { + unsafe { ffi::compilationFeedbackGetTotalOutputsSize(self._c) } + } } /// Parse the MLIR code and returns it. @@ -91,479 +971,9 @@ unsafe fn buffer_ref_to_bytes(buffer_ref: BufferRef) -> Vec { /// ``` /// pub fn round_trip(mlir_code: &str) -> Result { - unsafe { - let engine = compilerEngineCreate(); - let mlir_code_buffer = mlir_code.as_bytes(); - let compilation_result = compilerEngineCompile( - engine, - MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const c_char, - length: mlir_code_buffer.len() as size_t, - }, - CompilationTarget_ROUND_TRIP, - ); - if compilationResultIsNull(compilation_result) { - let error_msg = get_error_msg_from_ctype(compilation_result); - compilationResultDestroy(compilation_result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - let module_compiled = compilationResultGetModuleString(compilation_result); - let result_str = mlir_string_ref_to_string(module_compiled); - compilerEngineDestroy(engine); - Ok(result_str) - } -} - -/// Support for compiling and executing libraries. -pub struct LibrarySupport { - support: crate::mlir::ffi::LibrarySupport, -} - -impl Drop for LibrarySupport { - fn drop(&mut self) { - unsafe { - librarySupportDestroy(self.support); - } - } -} - -impl LibrarySupport { - /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. - /// - /// The compiled library needs to link to the runtime for proper execution. - pub fn new( - output_dir_path: &str, - runtime_library_path: &str, - ) -> Result { - unsafe { - let output_dir_path_buffer = output_dir_path.as_bytes(); - let runtime_library_path_buffer = runtime_library_path.as_bytes(); - let support = librarySupportCreateDefault( - MlirStringRef { - data: output_dir_path_buffer.as_ptr() as *const c_char, - length: output_dir_path_buffer.len() as size_t, - }, - MlirStringRef { - data: runtime_library_path_buffer.as_ptr() as *const c_char, - length: runtime_library_path_buffer.len() as size_t, - }, - ); - if librarySupportIsNull(support) { - let error_msg = get_error_msg_from_ctype(support); - librarySupportDestroy(support); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(LibrarySupport { support }) - } - } - - /// Compile an MLIR into a library. - pub fn compile( - &self, - mlir_code: &str, - options: Option, - ) -> Result { - unsafe { - let options = options.unwrap_or_else(|| compilationOptionsCreateDefault()); - let mlir_code_buffer = mlir_code.as_bytes(); - let result = librarySupportCompile( - self.support, - MlirStringRef { - data: mlir_code_buffer.as_ptr() as *const c_char, - length: mlir_code_buffer.len() as size_t, - }, - options, - ); - if libraryCompilationResultIsNull(result) { - let error_msg = get_error_msg_from_ctype(result); - libraryCompilationResultDestroy(result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(result) - } - } - - /// Load server lambda from a compilation result. - /// - /// This can be used for executing the compiled function. - pub fn load_server_lambda( - &self, - result: LibraryCompilationResult, - ) -> Result { - unsafe { - let server = librarySupportLoadServerLambda(self.support, result); - if serverLambdaIsNull(server) { - let error_msg = get_error_msg_from_ctype(server); - serverLambdaDestroy(server); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(server) - } - } - - /// Load client parameters from a compilation result. - /// - /// This can be used for creating keys for the compiled library. - pub fn load_client_parameters( - &self, - result: LibraryCompilationResult, - ) -> Result { - unsafe { - let params = librarySupportLoadClientParameters(self.support, result); - if clientParametersIsNull(params) { - let error_msg = get_error_msg_from_ctype(params); - clientParametersDestroy(params); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(params) - } - } - - /// Run a compiled circuit. - pub fn server_lambda_call( - &self, - server_lambda: ServerLambda, - args: PublicArguments, - eval_keys: EvaluationKeys, - ) -> Result { - unsafe { - let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys); - if publicResultIsNull(result) { - let error_msg = get_error_msg_from_ctype(result); - publicResultDestroy(result); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(result) - } - } - - /// Get path to the compiled shared library - pub fn get_shared_lib_path(&self) -> String { - unsafe { mlir_string_ref_to_string(librarySupportGetSharedLibPath(self.support)) } - } - - /// Get path to the client parameters - pub fn get_client_parameters_path(&self) -> String { - unsafe { mlir_string_ref_to_string(librarySupportGetClientParametersPath(self.support)) } - } -} - -/// Support for keygen, encryption, and decryption. -/// -/// Manages cache for keys if provided during creation. -pub struct ClientSupport { - client_params: crate::mlir::ffi::ClientParameters, - key_set_cache: Option, -} - -impl Drop for ClientSupport { - fn drop(&mut self) { - unsafe { - clientParametersDestroy(self.client_params); - match self.key_set_cache { - Some(cache) => keySetCacheDestroy(cache), - None => (), - } - } - } -} - -impl ClientSupport { - pub fn new( - client_params: ClientParameters, - key_set_cache_path: Option<&Path>, - ) -> Result { - unsafe { - let key_set_cache = match key_set_cache_path { - Some(path) => { - let cache_path_buffer = path.to_str().unwrap().as_bytes(); - let cache = keySetCacheCreate(MlirStringRef { - data: cache_path_buffer.as_ptr() as *const c_char, - length: cache_path_buffer.len() as size_t, - }); - if keySetCacheIsNull(cache) { - let error_msg = get_error_msg_from_ctype(cache); - keySetCacheDestroy(cache); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Some(cache) - } - None => None, - }; - Ok(ClientSupport { - client_params, - key_set_cache, - }) - } - } - - /// Fetch a keyset based on the client parameters, and the different seeds. - /// - /// If a cache has already been set, this operation would first try to load an existing key, - /// and generate a new one if no compatible keyset exists. - pub fn keyset( - &self, - seed_msb: Option, - seed_lsb: Option, - ) -> Result { - unsafe { - let key_set = match self.key_set_cache { - Some(cache) => keySetCacheLoadOrGenerateKeySet( - cache, - self.client_params, - seed_msb.unwrap_or(0), - seed_lsb.unwrap_or(0), - ), - None => keySetGenerate( - self.client_params, - seed_msb.unwrap_or(0), - seed_lsb.unwrap_or(0), - ), - }; - if keySetIsNull(key_set) { - let error_msg = get_error_msg_from_ctype(key_set); - keySetDestroy(key_set); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(key_set) - } - } - - /// Encrypt arguments of a compiled circuit. - pub fn encrypt_args( - &self, - args: &[LambdaArgument], - key_set: KeySet, - ) -> Result { - unsafe { - let public_args = lambdaArgumentEncrypt( - args.as_ptr(), - args.len() as u64, - self.client_params, - key_set, - ); - if publicArgumentsIsNull(public_args) { - let error_msg = get_error_msg_from_ctype(public_args); - publicArgumentsDestroy(public_args); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(public_args) - } - } - - pub fn decrypt_result( - &self, - result: PublicResult, - key_set: KeySet, - ) -> Result { - unsafe { - let arg = publicResultDecrypt(result, key_set); - if lambdaArgumentIsNull(arg) { - let error_msg = get_error_msg_from_ctype(arg); - lambdaArgumentDestroy(arg); - return Err(CompilerError(format!( - "Error in compiler (check logs for more info): {}", - error_msg - ))); - } - Ok(arg) - } - } -} - -// TODO: implement traits for C Struct that are serializable and reduce code for serialization and maybe refactor other functions. -// destroy and is_null could be implemented for other struct as well. -// -// trait Serializable { -// fn into_buffer_ref(self) -> BufferRef; -// fn from_buffer_ref(buff: BufferRef, params: Option) -> Self; -// fn is_null(self) -> bool; -// fn destroy(self); -// } - -// fn serialize(to_serialize: T) -> Result, CompilerError> { -// unsafe { -// let serialized_ref = to_serialize.into_buffer_ref(); -// if bufferRefIsNull(serialized_ref) { -// let error_msg = get_error_msg_from_ctype(serialized_ref); -// bufferRefDestroy(serialized_ref); -// return Err(CompilerError(error_msg)); -// } -// let serialized = buffer_ref_to_bytes(serialized_ref); -// Ok(serialized) -// } -// } - -// fn unserialize( -// serialized: &Vec, -// client_parameters: Option, -// ) -> Result { -// unsafe { -// let serialized_ref = bufferRefCreate( -// serialized.as_ptr() as *const c_char, -// serialized.len().try_into().unwrap(), -// ); -// let serialized = T::from_buffer_ref(serialized_ref, client_parameters); -// if serialized.is_null() { -// let error_msg = get_error_msg_from_ctype(serialized); -// serialized.destroy(); -// return Err(CompilerError(error_msg)); -// } -// Ok(serialized) -// } -// } - -impl PublicArguments { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = publicArgumentsSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize( - serialized: &Vec, - client_parameters: ClientParameters, - ) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let public_args = publicArgumentsUnserialize(serialized_ref, client_parameters); - if publicArgumentsIsNull(public_args) { - let error_msg = get_error_msg_from_ctype(public_args); - publicArgumentsDestroy(public_args); - return Err(CompilerError(error_msg)); - } - Ok(public_args) - } - } -} - -impl PublicResult { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = publicResultSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize( - serialized: &Vec, - client_parameters: ClientParameters, - ) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let public_result = publicResultUnserialize(serialized_ref, client_parameters); - if publicResultIsNull(public_result) { - let error_msg = get_error_msg_from_ctype(public_result); - publicResultDestroy(public_result); - return Err(CompilerError(error_msg)); - } - Ok(public_result) - } - } -} - -impl EvaluationKeys { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = evaluationKeysSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize(serialized: &Vec) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let eval_keys = evaluationKeysUnserialize(serialized_ref); - if evaluationKeysIsNull(eval_keys) { - let error_msg = get_error_msg_from_ctype(eval_keys); - evaluationKeysDestroy(eval_keys); - return Err(CompilerError(error_msg)); - } - Ok(eval_keys) - } - } -} - -impl ClientParameters { - pub fn serialize(self) -> Result, CompilerError> { - unsafe { - let serialized_ref = clientParametersSerialize(self); - if bufferRefIsNull(serialized_ref) { - let error_msg = get_error_msg_from_ctype(serialized_ref); - bufferRefDestroy(serialized_ref); - return Err(CompilerError(error_msg)); - } - let serialized = buffer_ref_to_bytes(serialized_ref); - Ok(serialized) - } - } - pub fn unserialize(serialized: &Vec) -> Result { - unsafe { - let serialized_ref = bufferRefCreate( - serialized.as_ptr() as *const c_char, - serialized.len().try_into().unwrap(), - ); - let params = clientParametersUnserialize(serialized_ref); - if clientParametersIsNull(params) { - let error_msg = get_error_msg_from_ctype(params); - clientParametersDestroy(params); - return Err(CompilerError(error_msg)); - } - Ok(params) - } - } + let engine = CompilerEngine::new(None).unwrap(); + let compilation_result = engine.compile(mlir_code, ffi::CompilationTarget_ROUND_TRIP)?; + compilation_result.get_module_string() } #[cfg(test)] @@ -573,6 +983,13 @@ mod test { use super::*; + fn get_runtime_lib_path() -> Option { + match env::var("CONCRETE_COMPILER_INSTALL_DIR") { + Ok(val) => Some(val + "/lib/libConcretelangRuntime.so"), + Err(_e) => None, + } + } + #[test] fn test_compiler_round_trip() { let module_to_compile = " @@ -602,262 +1019,192 @@ mod test { #[test] fn test_compiler_compile_lib() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_lib").unwrap(); - let support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) - .unwrap(); - let lib = support.compile(module_to_compile, None).unwrap(); - assert!(!libraryCompilationResultIsNull(lib)); - libraryCompilationResultDestroy(lib); - // the sharedlib should be enough as a sign that the compilation worked - assert!(Path::new(support.get_shared_lib_path().as_str()).exists()); - assert!(Path::new(support.get_client_parameters_path().as_str()).exists()); - } + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + let lib = support.compile(module_to_compile, None).unwrap(); + assert!(!lib.is_null()); + // the sharedlib should be enough as a sign that the compilation worked + assert!(Path::new(support.get_shared_lib_path().as_str()).exists()); + assert!(Path::new(support.get_client_parameters_path().as_str()).exists()); } /// We want to make sure setting a pointer to null in rust passes the nullptr check in C/Cpp #[test] fn test_compiler_null_ptr_compatibility() { unsafe { - let lib = Library { + let lib = ffi::Library { ptr: std::ptr::null_mut(), error: std::ptr::null_mut(), }; - assert!(libraryIsNull(lib)); + assert!(ffi::libraryIsNull(lib)); } } #[test] fn test_compiler_load_server_lambda_and_client_parameters() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_load_server_lambda").unwrap(); - let support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) - .unwrap(); - let result = support.compile(module_to_compile, None).unwrap(); - let server = support.load_server_lambda(result).unwrap(); - assert!(!serverLambdaIsNull(server)); - serverLambdaDestroy(server); - let client_params = support.load_client_parameters(result).unwrap(); - assert!(!clientParametersIsNull(client_params)); - libraryCompilationResultDestroy(result); - } + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + let result = support.compile(module_to_compile, None).unwrap(); + let server = support.load_server_lambda(&result).unwrap(); + assert!(!server.is_null()); + let client_params = support.load_client_parameters(&result).unwrap(); + assert!(!client_params.is_null()); } #[test] fn test_compiler_compile_and_exec_scalar_args() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_scalar_args").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_scalar(4).unwrap(), + LambdaArgument::from_scalar(2).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // build lambda arguments from scalar and encrypt them - let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // free args - args.map(|arg| lambdaArgumentDestroy(arg)); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // get the scalar value from the result lambda argument - let result = lambdaArgumentGetScalar(result_arg); - assert_eq!(result, 6); - } + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // get the scalar value from the result lambda argument + let result = result_arg.get_scalar().unwrap(); + assert_eq!(result, 6); } #[test] fn test_compiler_compile_and_exec_with_serialization() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = - TempDir::new("rust_test_compiler_compile_and_exec_with_serialization").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + // serialize client parameters + let serialized_params = client_params.serialize().unwrap(); + let client_params = ClientParameters::unserialize(&serialized_params).unwrap(); + // generate keys + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // serialize eval keys + let serialized_eval_keys = eval_keys.serialize().unwrap(); + let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_scalar(4).unwrap(), + LambdaArgument::from_scalar(2).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // serialize args + let serialized_encrypted_args = encrypted_args.serialize().unwrap(); + let encrypted_args = + PublicArguments::unserialize(&serialized_encrypted_args, &client_params).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - // serialize client parameters - let serialized_params = client_params.serialize().unwrap(); - let client_params = ClientParameters::unserialize(&serialized_params).unwrap(); - // create client support - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // serialize eval keys - let serialized_eval_keys = eval_keys.serialize().unwrap(); - let eval_keys = EvaluationKeys::unserialize(&serialized_eval_keys).unwrap(); - // build lambda arguments from scalar and encrypt them - let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // free args - args.map(|arg| lambdaArgumentDestroy(arg)); - // serialize args - let serialized_encrypted_args = encrypted_args.serialize().unwrap(); - let encrypted_args = - PublicArguments::unserialize(&serialized_encrypted_args, client_params).unwrap(); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // serialize result - let serialized_encrypted_result = encrypted_result.serialize().unwrap(); - let encrypted_result = - PublicResult::unserialize(&serialized_encrypted_result, client_params).unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // get the scalar value from the result lambda argument - let result = lambdaArgumentGetScalar(result_arg); - assert_eq!(result, 6); - } + // serialize result + let serialized_encrypted_result = encrypted_result.serialize().unwrap(); + let encrypted_result = + PublicResult::unserialize(&serialized_encrypted_result, &client_params).unwrap(); + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // get the scalar value from the result lambda argument + let result = result_arg.get_scalar().unwrap(); + assert_eq!(result, 6); } #[test] fn test_tensor_lambda_argument() { - unsafe { - let tensor_data = [1, 2, 3, 73u64]; - let tensor_dims = [2, 2i64]; - let tensor_arg = - lambdaArgumentFromTensorU64(tensor_data.as_ptr(), tensor_dims.as_ptr(), 2); - assert!(!lambdaArgumentIsNull(tensor_arg)); - assert!(!lambdaArgumentIsScalar(tensor_arg)); - assert!(lambdaArgumentIsTensor(tensor_arg)); - assert_eq!(lambdaArgumentGetTensorRank(tensor_arg), 2); - assert_eq!(lambdaArgumentGetTensorDataSize(tensor_arg), 4); - let mut dims: [i64; 2] = [0, 0]; - assert_eq!( - lambdaArgumentGetTensorDims(tensor_arg, dims.as_mut_ptr()), - true - ); - assert_eq!(dims, tensor_dims); - - let mut data: [u64; 4] = [0; 4]; - assert_eq!( - lambdaArgumentGetTensorData(tensor_arg, data.as_mut_ptr()), - true - ); - assert_eq!(data, tensor_data); - lambdaArgumentDestroy(tensor_arg); - } + let tensor_data = [1, 2, 3, 73u64]; + let tensor_dims = [2, 2i64]; + let tensor_arg = LambdaArgument::from_tensor_u64(&tensor_data, &tensor_dims).unwrap(); + assert!(!tensor_arg.is_null()); + assert!(!tensor_arg.is_scalar()); + assert!(tensor_arg.is_tensor()); + assert_eq!(tensor_arg.get_rank().unwrap(), 2); + assert_eq!(tensor_arg.get_data_size().unwrap(), 4); + assert_eq!(tensor_arg.get_dims().unwrap(), tensor_dims); + assert_eq!(tensor_arg.get_data().unwrap(), tensor_data); } #[test] fn test_compiler_compile_and_exec_tensor_args() { - unsafe { - let module_to_compile = " + let module_to_compile = " func.func @main(%arg0: tensor<2x3x!FHE.eint<5>>, %arg1: tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> { %0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<2x3x!FHE.eint<5>>, tensor<2x3x!FHE.eint<5>>) -> tensor<2x3x!FHE.eint<5>> return %0 : tensor<2x3x!FHE.eint<5>> }"; - let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { - Ok(val) => val + "/lib/libConcretelangRuntime.so", - Err(_e) => "".to_string(), - }; - let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_tensor_args").unwrap(); - let lib_support = LibrarySupport::new( - temp_dir.path().to_str().unwrap(), - runtime_library_path.as_str(), - ) + let runtime_library_path = get_runtime_lib_path(); + let temp_dir = TempDir::new("concrete_compiler_rust_test").unwrap(); + let lib_support = + LibrarySupport::new(temp_dir.path().to_str().unwrap(), runtime_library_path).unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(&result).unwrap(); + let client_params = lib_support.load_client_parameters(&result).unwrap(); + let key_set = KeySet::new(&client_params, None, None, None).unwrap(); + let eval_keys = key_set.get_evaluation_keys().unwrap(); + // build lambda arguments from scalar and encrypt them + let args = [ + LambdaArgument::from_tensor_u8(&[1, 2, 3, 4, 5, 6], &[2, 3]).unwrap(), + LambdaArgument::from_tensor_u8(&[1, 4, 7, 4, 2, 9], &[2, 3]).unwrap(), + ]; + let encrypted_args = key_set.encrypt_args(&args).unwrap(); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(&server_lambda, &encrypted_args, &eval_keys) .unwrap(); - // compile - let result = lib_support.compile(module_to_compile, None).unwrap(); - // loading materials from compilation - // - server_lambda: used for execution - // - client_parameters: used for keygen, encryption, and evaluation keys - let server_lambda = lib_support.load_server_lambda(result).unwrap(); - let client_params = lib_support.load_client_parameters(result).unwrap(); - let client_support = ClientSupport::new(client_params, None).unwrap(); - let key_set = client_support.keyset(None, None).unwrap(); - let eval_keys = keySetGetEvaluationKeys(key_set); - // build lambda arguments from scalar and encrypt them - let args = [ - lambdaArgumentFromTensorU8([1, 2, 3, 4, 5, 6].as_ptr(), [2, 3].as_ptr(), 2), - lambdaArgumentFromTensorU8([1, 4, 7, 4, 2, 9].as_ptr(), [2, 3].as_ptr(), 2), - ]; - let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); - // execute the compiled function on the encrypted arguments - let encrypted_result = lib_support - .server_lambda_call(server_lambda, encrypted_args, eval_keys) - .unwrap(); - // decrypt the result of execution - let result_arg = client_support - .decrypt_result(encrypted_result, key_set) - .unwrap(); - // check the tensor dims value from the result lambda argument - assert_eq!(lambdaArgumentGetTensorRank(result_arg), 2); - assert_eq!(lambdaArgumentGetTensorDataSize(result_arg), 6); - let mut dims = [0, 0]; - assert!(lambdaArgumentGetTensorDims(result_arg, dims.as_mut_ptr())); - assert_eq!(dims, [2, 3]); - // check the tensor data from the result lambda argument - let mut data = [0; 6]; - assert!(lambdaArgumentGetTensorData(result_arg, data.as_mut_ptr())); - assert_eq!(data, [2, 6, 10, 8, 7, 15]); - } + // decrypt the result of execution + let result_arg = key_set.decrypt_result(&encrypted_result).unwrap(); + // check the tensor dims value from the result lambda argument + assert_eq!(result_arg.get_rank().unwrap(), 2); + assert_eq!(result_arg.get_data_size().unwrap(), 6); + assert_eq!(result_arg.get_dims().unwrap(), [2, 3]); + // check the tensor data from the result lambda argument + assert_eq!(result_arg.get_data().unwrap(), [2, 6, 10, 8, 7, 15]); } } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 2adbe6a0b..c90840328 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -84,13 +84,16 @@ CompilationOptions compilationOptionsCreateDefault() { return wrap(new mlir::concretelang::CompilationOptions("main")); } +void compilationOptionsDestroy(CompilationOptions options){ + C_STRUCT_CLEANER(options)} + /// ********** OptimizerConfig CAPI ******************************************** -OptimizerConfig optimizerConfigCreate(bool display, - double fallback_log_norm_woppbs, - double global_p_error, double p_error, - uint64_t security, bool strategy_v0, - bool use_gpu_constraints) { +OptimizerConfig + optimizerConfigCreate(bool display, double fallback_log_norm_woppbs, + double global_p_error, double p_error, + uint64_t security, bool strategy_v0, + bool use_gpu_constraints) { auto config = new mlir::concretelang::optimizer::Config(); config->display = display; config->fallback_log_norm_woppbs = fallback_log_norm_woppbs; @@ -106,6 +109,8 @@ OptimizerConfig optimizerConfigCreateDefault() { return wrap(new mlir::concretelang::optimizer::Config()); } +void optimizerConfigDestroy(OptimizerConfig config){C_STRUCT_CLEANER(config)} + /// ********** CompilerEngine CAPI ********************************************* CompilerEngine compilerEngineCreate() { @@ -330,6 +335,10 @@ ClientParameters clientParametersUnserialize(BufferRef buffer) { return wrap(new mlir::concretelang::ClientParameters(paramsOrError.get())); } +ClientParameters clientParametersCopy(ClientParameters params) { + return wrap(new mlir::concretelang::ClientParameters(*unwrap(params))); +} + void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)} /// ********** KeySet CAPI *****************************************************