diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index b489d03df..bb91b7ec5 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -33,6 +33,13 @@ DEFINE_C_API_STRUCT(LibrarySupport, void); DEFINE_C_API_STRUCT(CompilationOptions, void); DEFINE_C_API_STRUCT(OptimizerConfig, void); DEFINE_C_API_STRUCT(ServerLambda, void); +DEFINE_C_API_STRUCT(ClientParameters, void); +DEFINE_C_API_STRUCT(KeySet, void); +DEFINE_C_API_STRUCT(KeySetCache, void); +DEFINE_C_API_STRUCT(EvaluationKeys, void); +DEFINE_C_API_STRUCT(LambdaArgument, void); +DEFINE_C_API_STRUCT(PublicArguments, void); +DEFINE_C_API_STRUCT(PublicResult, void); #undef DEFINE_C_API_STRUCT @@ -51,6 +58,13 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport); DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions); DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig); DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda); +DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters); +DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet); +DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache); +DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys); +DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument); +DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments); +DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult); #undef DEFINE_NULL_PTR_CHECKER @@ -146,10 +160,81 @@ MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile( MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda( LibrarySupport support, LibraryCompilationResult result); +MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters( + LibrarySupport support, LibraryCompilationResult result); + +MLIR_CAPI_EXPORTED PublicResult +librarySupportServerCall(LibrarySupport support, ServerLambda server, + PublicArguments args, EvaluationKeys evalKeys); + +MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support); + /// ********** ServerLamda CAPI ************************************************ MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server); +/// ********** ClientParameters CAPI ******************************************* + +MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params); + +/// ********** KeySet CAPI ***************************************************** + +MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params, + uint64_t seed_msb, uint64_t seed_lsb); + +MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet); + +MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet); + +/// ********** KeySetCache CAPI ************************************************ + +MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath); + +MLIR_CAPI_EXPORTED KeySet +keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params, + uint64_t seed_msb, uint64_t seed_lsb); + +MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache); + +/// ********** EvaluationKeys CAPI ********************************************* + +MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys); + +/// ********** LambdaArgument CAPI ********************************************* + +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value); + +MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, + int64_t *dims, + size_t rank); + +MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg); + +MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED uint64_t * +lambdaArgumentGetTensorData(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg); +MLIR_CAPI_EXPORTED int64_t * +lambdaArgumentGetTensorDims(LambdaArgument lambdaArg); + +MLIR_CAPI_EXPORTED PublicArguments +lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber, + ClientParameters params, KeySet keySet); + +MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg); + +/// ********** PublicArguments CAPI ******************************************** + +MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs); + +/// ********** PublicResult CAPI *********************************************** + +MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult, + KeySet keySet); + +MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult); + #ifdef __cplusplus } #endif diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h index 8ab82bb85..6b15e1b7d 100644 --- a/compiler/include/concretelang/CAPI/Wrappers.h +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -25,5 +25,17 @@ DEFINE_C_API_PTR_METHODS(CompilationOptions, DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config) DEFINE_C_API_PTR_METHODS(ServerLambda, mlir::concretelang::serverlib::ServerLambda) +DEFINE_C_API_PTR_METHODS(ClientParameters, + mlir::concretelang::clientlib::ClientParameters) +DEFINE_C_API_PTR_METHODS(KeySet, mlir::concretelang::clientlib::KeySet) +DEFINE_C_API_PTR_METHODS(KeySetCache, + mlir::concretelang::clientlib::KeySetCache) +DEFINE_C_API_PTR_METHODS(EvaluationKeys, + mlir::concretelang::clientlib::EvaluationKeys) +DEFINE_C_API_PTR_METHODS(LambdaArgument, mlir::concretelang::LambdaArgument) +DEFINE_C_API_PTR_METHODS(PublicArguments, + mlir::concretelang::clientlib::PublicArguments) +DEFINE_C_API_PTR_METHODS(PublicResult, + mlir::concretelang::clientlib::PublicResult) #endif diff --git a/compiler/include/concretelang/ClientLib/KeySetCache.h b/compiler/include/concretelang/ClientLib/KeySetCache.h index 989b15a83..948aad54d 100644 --- a/compiler/include/concretelang/ClientLib/KeySetCache.h +++ b/compiler/include/concretelang/ClientLib/KeySetCache.h @@ -24,6 +24,9 @@ public: generate(std::shared_ptr optionalCache, ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); + outcome::checked, StringError> + generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb); + private: static outcome::checked, StringError> loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb, diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 2752c86c6..82b5a6fb9 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -1,12 +1,11 @@ //! Compiler module +use std::path::Path; + use crate::mlir::ffi::*; #[derive(Debug)] -pub struct CompilationError(String); - -#[derive(Debug)] -pub struct ServerLambdaLoadError(String); +pub struct CompilerError(String); /// Parse the MLIR code and returns it. /// @@ -25,7 +24,7 @@ pub struct ServerLambdaLoadError(String); /// let result_str = round_trip(module_to_compile); /// ``` /// -pub fn round_trip(mlir_code: &str) -> Result { +pub fn round_trip(mlir_code: &str) -> Result { unsafe { let engine = compilerEngineCreate(); let mlir_code_buffer = mlir_code.as_bytes(); @@ -38,7 +37,7 @@ pub fn round_trip(mlir_code: &str) -> Result { CompilationTarget_ROUND_TRIP, ); if compilationResultIsNull(compilation_result) { - return Err(CompilationError("roundtrip error".to_string())); + return Err(CompilerError("roundtrip error".to_string())); } let module_compiled = compilationResultGetModuleString(compilation_result); let result_str = String::from_utf8_lossy(std::slice::from_raw_parts( @@ -57,26 +56,39 @@ pub struct LibrarySupport { support: crate::mlir::ffi::LibrarySupport, } +impl Drop for LibrarySupport { + fn drop(&mut self) { + unsafe { + librarySupportDestroy(self.support); + } + } +} + impl LibrarySupport { /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. /// /// The compiled library needs to link to the runtime for proper execution. - pub fn new(output_dir_path: &str, runtime_library_path: &str) -> LibrarySupport { + pub fn new( + output_dir_path: &str, + runtime_library_path: &str, + ) -> Result { unsafe { let output_dir_path_buffer = output_dir_path.as_bytes(); let runtime_library_path_buffer = runtime_library_path.as_bytes(); - LibrarySupport { - support: librarySupportCreateDefault( - MlirStringRef { - data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char, - length: output_dir_path_buffer.len() as size_t, - }, - MlirStringRef { - data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char, - length: runtime_library_path_buffer.len() as size_t, - }, - ), + let support = librarySupportCreateDefault( + MlirStringRef { + data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char, + length: output_dir_path_buffer.len() as size_t, + }, + MlirStringRef { + data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char, + length: runtime_library_path_buffer.len() as size_t, + }, + ); + if librarySupportIsNull(support) { + return Err(CompilerError("failed creating library support".to_string())); } + Ok(LibrarySupport { support }) } } @@ -85,7 +97,7 @@ impl LibrarySupport { &self, mlir_code: &str, options: Option, - ) -> Result { + ) -> Result { unsafe { let options = options.unwrap_or_else(|| compilationOptionsCreateDefault()); let mlir_code_buffer = mlir_code.as_bytes(); @@ -98,7 +110,7 @@ impl LibrarySupport { options, ); if libraryCompilationResultIsNull(result) { - return Err(CompilationError("library compilation failed".to_string())); + return Err(CompilerError("library compilation failed".to_string())); } Ok(result) } @@ -110,17 +122,163 @@ impl LibrarySupport { pub fn load_server_lambda( &self, result: LibraryCompilationResult, - ) -> Result { + ) -> Result { unsafe { let server = librarySupportLoadServerLambda(self.support, result); if serverLambdaIsNull(server) { - return Err(ServerLambdaLoadError( - "loading server lambda failed".to_string(), - )); + return Err(CompilerError("loading server lambda failed".to_string())); } Ok(server) } } + + /// Load client parameters from a compilation result. + /// + /// This can be used for creating keys for the compiled library. + pub fn load_client_parameters( + &self, + result: LibraryCompilationResult, + ) -> Result { + unsafe { + let params = librarySupportLoadClientParameters(self.support, result); + if clientParametersIsNull(params) { + return Err(CompilerError( + "loading client parameters failed".to_string(), + )); + } + Ok(params) + } + } + + /// Run a compiled circuit. + pub fn server_lambda_call( + &self, + server_lambda: ServerLambda, + args: PublicArguments, + eval_keys: EvaluationKeys, + ) -> Result { + unsafe { + let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys); + if publicResultIsNull(result) { + return Err(CompilerError("failed calling server lambda".to_string())); + } + Ok(result) + } + } +} + +/// Support for keygen, encryption, and decryption. +/// +/// Manages cache for keys if provided during creation. +pub struct ClientSupport { + client_params: crate::mlir::ffi::ClientParameters, + key_set_cache: Option, +} + +impl Drop for ClientSupport { + fn drop(&mut self) { + unsafe { + clientParametersDestroy(self.client_params); + match self.key_set_cache { + Some(cache) => keySetCacheDestroy(cache), + None => (), + } + } + } +} + +impl ClientSupport { + pub fn new( + client_params: ClientParameters, + key_set_cache_path: Option<&Path>, + ) -> Result { + unsafe { + let key_set_cache = match key_set_cache_path { + Some(path) => { + let cache_path_buffer = path.to_str().unwrap().as_bytes(); + let cache = keySetCacheCreate(MlirStringRef { + data: cache_path_buffer.as_ptr() as *const std::os::raw::c_char, + length: cache_path_buffer.len() as size_t, + }); + if keySetCacheIsNull(cache) { + return Err(CompilerError( + "failed creating keyset cache from path".to_string(), + )); + } + Some(cache) + } + None => None, + }; + Ok(ClientSupport { + client_params, + key_set_cache, + }) + } + } + + /// Fetch a keyset based on the client parameters, and the different seeds. + /// + /// If a cache has already been set, this operation would first try to load an existing key, + /// and generate a new one if no compatible keyset exists. + pub fn keyset( + &self, + seed_msb: Option, + seed_lsb: Option, + ) -> Result { + unsafe { + let key_set = match self.key_set_cache { + Some(cache) => keySetCacheLoadOrGenerateKeySet( + cache, + self.client_params, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + ), + None => keySetGenerate( + self.client_params, + seed_msb.unwrap_or(0), + seed_lsb.unwrap_or(0), + ), + }; + if keySetIsNull(key_set) { + return Err(CompilerError("getting keyset failed".to_string())); + } + Ok(key_set) + } + } + + /// Encrypt arguments of a compiled circuit. + pub fn encrypt_args( + &self, + args: &[LambdaArgument], + key_set: KeySet, + ) -> Result { + unsafe { + let public_args = lambdaArgumentEncrypt( + args.as_ptr(), + args.len() as u64, + self.client_params, + key_set, + ); + if publicArgumentsIsNull(public_args) { + return Err(CompilerError("encryption failed".to_string())); + } + Ok(public_args) + } + } + + pub fn decrypt_result( + &self, + result: PublicResult, + key_set: KeySet, + ) -> Result { + unsafe { + let arg = publicResultDecrypt(result, key_set); + if lambdaArgumentIsNull(arg) { + return Err(CompilerError("decryption failed".to_string())); + } + Ok(arg) + } + } } #[cfg(test)] @@ -152,7 +310,7 @@ mod test { fn test_compiler_round_trip_invalid_mlir() { let module_to_compile = "bla bla bla"; let result_str = round_trip(module_to_compile); - assert!(matches!(result_str, Err(CompilationError(_)))); + assert!(matches!(result_str, Err(CompilerError(_)))); } #[test] @@ -171,7 +329,8 @@ mod test { let support = LibrarySupport::new( temp_dir.path().to_str().unwrap(), runtime_library_path.as_str(), - ); + ) + .unwrap(); let lib = support.compile(module_to_compile, None).unwrap(); assert!(!libraryCompilationResultIsNull(lib)); libraryCompilationResultDestroy(lib); @@ -192,7 +351,7 @@ mod test { } #[test] - fn test_compiler_load_server_lambda() { + fn test_compiler_load_server_lambda_and_client_parameters() { unsafe { let module_to_compile = " func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { @@ -207,12 +366,62 @@ mod test { let support = LibrarySupport::new( temp_dir.path().to_str().unwrap(), runtime_library_path.as_str(), - ); + ) + .unwrap(); let result = support.compile(module_to_compile, None).unwrap(); let server = support.load_server_lambda(result).unwrap(); assert!(!serverLambdaIsNull(server)); - libraryCompilationResultDestroy(result); serverLambdaDestroy(server); + let client_params = support.load_client_parameters(result).unwrap(); + assert!(!clientParametersIsNull(client_params)); + libraryCompilationResultDestroy(result); + } + } + + #[test] + fn test_compiler_compile_and_exec_scalar_args() { + unsafe { + let module_to_compile = " + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + }"; + let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { + Ok(val) => val + "/lib/libConcretelangRuntime.so", + Err(_e) => "".to_string(), + }; + let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_scalar_args").unwrap(); + let lib_support = LibrarySupport::new( + temp_dir.path().to_str().unwrap(), + runtime_library_path.as_str(), + ) + .unwrap(); + // compile + let result = lib_support.compile(module_to_compile, None).unwrap(); + // loading materials from compilation + // - server_lambda: used for execution + // - client_parameters: used for keygen, encryption, and evaluation keys + let server_lambda = lib_support.load_server_lambda(result).unwrap(); + let client_params = lib_support.load_client_parameters(result).unwrap(); + let client_support = ClientSupport::new(client_params, None).unwrap(); + let key_set = client_support.keyset(None, None).unwrap(); + let eval_keys = keySetGetEvaluationKeys(key_set); + // build lambda arguments from scalar and encrypt them + let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)]; + let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap(); + // free args + args.map(|arg| lambdaArgumentDestroy(arg)); + // execute the compiled function on the encrypted arguments + let encrypted_result = lib_support + .server_lambda_call(server_lambda, encrypted_args, eval_keys) + .unwrap(); + // decrypt the result of execution + let result_arg = client_support + .decrypt_result(encrypted_result, key_set) + .unwrap(); + // get the scalar value from the result lambda argument + let result = lambdaArgumentGetScalar(result_arg); + assert_eq!(result, 6); } } } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index cdbab3bcb..15d54c136 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -7,6 +7,7 @@ #include "concretelang/CAPI/Wrappers.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" +#include "concretelang/Support/LambdaSupport.h" #include "mlir/IR/Diagnostics.h" #include "llvm/Support/SourceMgr.h" @@ -186,7 +187,7 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support, return wrap((mlir::concretelang::LibraryCompilationResult *)NULL); } return wrap(new mlir::concretelang::LibraryCompilationResult( - *retOrError.get().get())); + *retOrError.get().release())); } ServerLambda librarySupportLoadServerLambda(LibrarySupport support, @@ -200,6 +201,169 @@ ServerLambda librarySupportLoadServerLambda(LibrarySupport support, serverLambdaOrError.get())); } +ClientParameters +librarySupportLoadClientParameters(LibrarySupport support, + LibraryCompilationResult result) { + auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result)); + if (!paramsOrError) { + llvm::errs() << llvm::toString(paramsOrError.takeError()); + return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL); + } + return wrap( + new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get())); +} + +PublicResult librarySupportServerCall(LibrarySupport support, + ServerLambda server_lambda, + PublicArguments args, + EvaluationKeys evalKeys) { + auto resultOrError = unwrap(support)->serverCall( + *unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys)); + if (!resultOrError) { + llvm::errs() << llvm::toString(resultOrError.takeError()); + return wrap((mlir::concretelang::clientlib::PublicResult *)NULL); + } + return wrap(resultOrError.get().release()); +} + +void librarySupportDestroy(LibrarySupport support) { delete unwrap(support); } + /// ********** ServerLamda CAPI ************************************************ void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); } + +/// ********** ClientParameters CAPI ******************************************* + +void clientParametersDestroy(ClientParameters params) { delete unwrap(params); } + +/// ********** KeySet CAPI ***************************************************** + +KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb, + uint64_t seed_lsb) { + auto keySet = mlir::concretelang::clientlib::KeySet::generate( + *unwrap(params), seed_msb, seed_lsb); + if (keySet.has_error()) { + llvm::errs() << keySet.error().mesg; + return wrap((mlir::concretelang::clientlib::KeySet *)NULL); + } + return wrap(keySet.value().release()); +} + +EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) { + return wrap(new mlir::concretelang::clientlib::EvaluationKeys( + unwrap(keySet)->evaluationKeys())); +} + +void keySetDestroy(KeySet keySet) { delete unwrap(keySet); } + +/// ********** KeySetCache CAPI ************************************************ + +KeySetCache keySetCacheCreate(MlirStringRef cachePath) { + std::string cachePathStr(cachePath.data, cachePath.length); + return wrap(new mlir::concretelang::clientlib::KeySetCache(cachePathStr)); +} + +KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache, + ClientParameters params, + uint64_t seed_msb, uint64_t seed_lsb) { + auto keySetOrError = + unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb); + if (keySetOrError.has_error()) { + llvm::errs() << keySetOrError.error().mesg; + return wrap((mlir::concretelang::clientlib::KeySet *)NULL); + } + return wrap(keySetOrError.value().release()); +} + +void keySetCacheDestroy(KeySetCache keySetCache) { delete unwrap(keySetCache); } + +/// ********** EvaluationKeys CAPI ********************************************* + +void evaluationKeysDestroy(EvaluationKeys evaluationKeys) { + delete unwrap(evaluationKeys); +} + +/// ********** LambdaArgument CAPI ********************************************* + +LambdaArgument lambdaArgumentFromScalar(uint64_t value) { + return wrap(new mlir::concretelang::IntLambdaArgument(value)); +} + +// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims, +// size_t rank); + +bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) { + return unwrap(lambdaArg) + ->isa>(); +} + +uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg) { + mlir::concretelang::IntLambdaArgument *arg = + unwrap(lambdaArg) + ->dyn_cast>(); + assert(arg != nullptr && "lambda argument isn't a scalar"); + return arg->getValue(); +} + +bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) { + return unwrap(lambdaArg) + ->isa>>() || + unwrap(lambdaArg) + ->isa>>() || + unwrap(lambdaArg) + ->isa>>() || + unwrap(lambdaArg) + ->isa>>(); +} + +// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg); +// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg); +// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg); + +PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, + size_t argNumber, ClientParameters params, + KeySet keySet) { + std::vector args; + for (size_t i = 0; i < argNumber; i++) + args.push_back(unwrap(lambdaArgs[i])); + auto publicArgsOrError = + mlir::concretelang::LambdaSupport::exportArguments( + *unwrap(params), *unwrap(keySet), args); + if (!publicArgsOrError) { + llvm::errs() << llvm::toString(publicArgsOrError.takeError()); + return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL); + } + return wrap(publicArgsOrError.get().release()); +} + +void lambdaArgumentDestroy(LambdaArgument lambdaArg) { + delete unwrap(lambdaArg); +} + +/// ********** PublicArguments CAPI ******************************************** + +void publicArgumentsDestroy(PublicArguments publicArgs) { + delete unwrap(publicArgs); +} + +/// ********** PublicResult CAPI *********************************************** + +LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) { + llvm::Expected> + lambdaArgOrError = mlir::concretelang::typedResult< + std::unique_ptr>( + *unwrap(keySet), *unwrap(publicResult)); + if (!lambdaArgOrError) { + llvm::errs() << llvm::toString(lambdaArgOrError.takeError()); + return wrap((mlir::concretelang::LambdaArgument *)NULL); + } + return wrap(lambdaArgOrError.get().release()); +} + +void publicResultDestroy(PublicResult publicResult) { + delete unwrap(publicResult); +} diff --git a/compiler/lib/ClientLib/KeySetCache.cpp b/compiler/lib/ClientLib/KeySetCache.cpp index 79a3f29be..b1dd36317 100644 --- a/compiler/lib/ClientLib/KeySetCache.cpp +++ b/compiler/lib/ClientLib/KeySetCache.cpp @@ -353,5 +353,11 @@ KeySetCache::generate(std::shared_ptr cache, : KeySet::generate(params, seed_msb, seed_lsb); } +outcome::checked, StringError> +KeySetCache::generate(ClientParameters ¶ms, uint64_t seed_msb, + uint64_t seed_lsb) { + return loadOrGenerateSave(params, seed_msb, seed_lsb); +} + } // namespace clientlib } // namespace concretelang