mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(rust): support keygen, encryption, execution
This commit is contained in:
@@ -33,6 +33,13 @@ DEFINE_C_API_STRUCT(LibrarySupport, void);
|
||||
DEFINE_C_API_STRUCT(CompilationOptions, void);
|
||||
DEFINE_C_API_STRUCT(OptimizerConfig, void);
|
||||
DEFINE_C_API_STRUCT(ServerLambda, void);
|
||||
DEFINE_C_API_STRUCT(ClientParameters, void);
|
||||
DEFINE_C_API_STRUCT(KeySet, void);
|
||||
DEFINE_C_API_STRUCT(KeySetCache, void);
|
||||
DEFINE_C_API_STRUCT(EvaluationKeys, void);
|
||||
DEFINE_C_API_STRUCT(LambdaArgument, void);
|
||||
DEFINE_C_API_STRUCT(PublicArguments, void);
|
||||
DEFINE_C_API_STRUCT(PublicResult, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
@@ -51,6 +58,13 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport);
|
||||
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions);
|
||||
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig);
|
||||
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda);
|
||||
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters);
|
||||
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet);
|
||||
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache);
|
||||
DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys);
|
||||
DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument);
|
||||
DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments);
|
||||
DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult);
|
||||
|
||||
#undef DEFINE_NULL_PTR_CHECKER
|
||||
|
||||
@@ -146,10 +160,81 @@ MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
|
||||
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters(
|
||||
LibrarySupport support, LibraryCompilationResult result);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicResult
|
||||
librarySupportServerCall(LibrarySupport support, ServerLambda server,
|
||||
PublicArguments args, EvaluationKeys evalKeys);
|
||||
|
||||
MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support);
|
||||
|
||||
/// ********** ServerLamda CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
|
||||
|
||||
/// ********** ClientParameters CAPI *******************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet);
|
||||
|
||||
/// ********** KeySetCache CAPI ************************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath);
|
||||
|
||||
MLIR_CAPI_EXPORTED KeySet
|
||||
keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
|
||||
|
||||
/// ********** EvaluationKeys CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
|
||||
|
||||
/// ********** LambdaArgument CAPI *********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
|
||||
int64_t *dims,
|
||||
size_t rank);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
|
||||
|
||||
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED uint64_t *
|
||||
lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
MLIR_CAPI_EXPORTED int64_t *
|
||||
lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
|
||||
|
||||
MLIR_CAPI_EXPORTED PublicArguments
|
||||
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
|
||||
ClientParameters params, KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
|
||||
|
||||
/// ********** PublicArguments CAPI ********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
|
||||
|
||||
/// ********** PublicResult CAPI ***********************************************
|
||||
|
||||
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
|
||||
KeySet keySet);
|
||||
|
||||
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -25,5 +25,17 @@ DEFINE_C_API_PTR_METHODS(CompilationOptions,
|
||||
DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config)
|
||||
DEFINE_C_API_PTR_METHODS(ServerLambda,
|
||||
mlir::concretelang::serverlib::ServerLambda)
|
||||
DEFINE_C_API_PTR_METHODS(ClientParameters,
|
||||
mlir::concretelang::clientlib::ClientParameters)
|
||||
DEFINE_C_API_PTR_METHODS(KeySet, mlir::concretelang::clientlib::KeySet)
|
||||
DEFINE_C_API_PTR_METHODS(KeySetCache,
|
||||
mlir::concretelang::clientlib::KeySetCache)
|
||||
DEFINE_C_API_PTR_METHODS(EvaluationKeys,
|
||||
mlir::concretelang::clientlib::EvaluationKeys)
|
||||
DEFINE_C_API_PTR_METHODS(LambdaArgument, mlir::concretelang::LambdaArgument)
|
||||
DEFINE_C_API_PTR_METHODS(PublicArguments,
|
||||
mlir::concretelang::clientlib::PublicArguments)
|
||||
DEFINE_C_API_PTR_METHODS(PublicResult,
|
||||
mlir::concretelang::clientlib::PublicResult)
|
||||
|
||||
#endif
|
||||
|
||||
@@ -24,6 +24,9 @@ public:
|
||||
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters ¶ms,
|
||||
uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
generate(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb);
|
||||
|
||||
private:
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
loadKeys(ClientParameters ¶ms, uint64_t seed_msb, uint64_t seed_lsb,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
//! Compiler module
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use crate::mlir::ffi::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CompilationError(String);
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ServerLambdaLoadError(String);
|
||||
pub struct CompilerError(String);
|
||||
|
||||
/// Parse the MLIR code and returns it.
|
||||
///
|
||||
@@ -25,7 +24,7 @@ pub struct ServerLambdaLoadError(String);
|
||||
/// let result_str = round_trip(module_to_compile);
|
||||
/// ```
|
||||
///
|
||||
pub fn round_trip(mlir_code: &str) -> Result<String, CompilationError> {
|
||||
pub fn round_trip(mlir_code: &str) -> Result<String, CompilerError> {
|
||||
unsafe {
|
||||
let engine = compilerEngineCreate();
|
||||
let mlir_code_buffer = mlir_code.as_bytes();
|
||||
@@ -38,7 +37,7 @@ pub fn round_trip(mlir_code: &str) -> Result<String, CompilationError> {
|
||||
CompilationTarget_ROUND_TRIP,
|
||||
);
|
||||
if compilationResultIsNull(compilation_result) {
|
||||
return Err(CompilationError("roundtrip error".to_string()));
|
||||
return Err(CompilerError("roundtrip error".to_string()));
|
||||
}
|
||||
let module_compiled = compilationResultGetModuleString(compilation_result);
|
||||
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
|
||||
@@ -57,26 +56,39 @@ pub struct LibrarySupport {
|
||||
support: crate::mlir::ffi::LibrarySupport,
|
||||
}
|
||||
|
||||
impl Drop for LibrarySupport {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
librarySupportDestroy(self.support);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl LibrarySupport {
|
||||
/// LibrarySupport manages build files generated by the compiler under the `output_dir_path`.
|
||||
///
|
||||
/// The compiled library needs to link to the runtime for proper execution.
|
||||
pub fn new(output_dir_path: &str, runtime_library_path: &str) -> LibrarySupport {
|
||||
pub fn new(
|
||||
output_dir_path: &str,
|
||||
runtime_library_path: &str,
|
||||
) -> Result<LibrarySupport, CompilerError> {
|
||||
unsafe {
|
||||
let output_dir_path_buffer = output_dir_path.as_bytes();
|
||||
let runtime_library_path_buffer = runtime_library_path.as_bytes();
|
||||
LibrarySupport {
|
||||
support: librarySupportCreateDefault(
|
||||
MlirStringRef {
|
||||
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: output_dir_path_buffer.len() as size_t,
|
||||
},
|
||||
MlirStringRef {
|
||||
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: runtime_library_path_buffer.len() as size_t,
|
||||
},
|
||||
),
|
||||
let support = librarySupportCreateDefault(
|
||||
MlirStringRef {
|
||||
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: output_dir_path_buffer.len() as size_t,
|
||||
},
|
||||
MlirStringRef {
|
||||
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: runtime_library_path_buffer.len() as size_t,
|
||||
},
|
||||
);
|
||||
if librarySupportIsNull(support) {
|
||||
return Err(CompilerError("failed creating library support".to_string()));
|
||||
}
|
||||
Ok(LibrarySupport { support })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +97,7 @@ impl LibrarySupport {
|
||||
&self,
|
||||
mlir_code: &str,
|
||||
options: Option<CompilationOptions>,
|
||||
) -> Result<LibraryCompilationResult, CompilationError> {
|
||||
) -> Result<LibraryCompilationResult, CompilerError> {
|
||||
unsafe {
|
||||
let options = options.unwrap_or_else(|| compilationOptionsCreateDefault());
|
||||
let mlir_code_buffer = mlir_code.as_bytes();
|
||||
@@ -98,7 +110,7 @@ impl LibrarySupport {
|
||||
options,
|
||||
);
|
||||
if libraryCompilationResultIsNull(result) {
|
||||
return Err(CompilationError("library compilation failed".to_string()));
|
||||
return Err(CompilerError("library compilation failed".to_string()));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
@@ -110,17 +122,163 @@ impl LibrarySupport {
|
||||
pub fn load_server_lambda(
|
||||
&self,
|
||||
result: LibraryCompilationResult,
|
||||
) -> Result<ServerLambda, ServerLambdaLoadError> {
|
||||
) -> Result<ServerLambda, CompilerError> {
|
||||
unsafe {
|
||||
let server = librarySupportLoadServerLambda(self.support, result);
|
||||
if serverLambdaIsNull(server) {
|
||||
return Err(ServerLambdaLoadError(
|
||||
"loading server lambda failed".to_string(),
|
||||
));
|
||||
return Err(CompilerError("loading server lambda failed".to_string()));
|
||||
}
|
||||
Ok(server)
|
||||
}
|
||||
}
|
||||
|
||||
/// Load client parameters from a compilation result.
|
||||
///
|
||||
/// This can be used for creating keys for the compiled library.
|
||||
pub fn load_client_parameters(
|
||||
&self,
|
||||
result: LibraryCompilationResult,
|
||||
) -> Result<ClientParameters, CompilerError> {
|
||||
unsafe {
|
||||
let params = librarySupportLoadClientParameters(self.support, result);
|
||||
if clientParametersIsNull(params) {
|
||||
return Err(CompilerError(
|
||||
"loading client parameters failed".to_string(),
|
||||
));
|
||||
}
|
||||
Ok(params)
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a compiled circuit.
|
||||
pub fn server_lambda_call(
|
||||
&self,
|
||||
server_lambda: ServerLambda,
|
||||
args: PublicArguments,
|
||||
eval_keys: EvaluationKeys,
|
||||
) -> Result<PublicResult, CompilerError> {
|
||||
unsafe {
|
||||
let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys);
|
||||
if publicResultIsNull(result) {
|
||||
return Err(CompilerError("failed calling server lambda".to_string()));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Support for keygen, encryption, and decryption.
|
||||
///
|
||||
/// Manages cache for keys if provided during creation.
|
||||
pub struct ClientSupport {
|
||||
client_params: crate::mlir::ffi::ClientParameters,
|
||||
key_set_cache: Option<KeySetCache>,
|
||||
}
|
||||
|
||||
impl Drop for ClientSupport {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
clientParametersDestroy(self.client_params);
|
||||
match self.key_set_cache {
|
||||
Some(cache) => keySetCacheDestroy(cache),
|
||||
None => (),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientSupport {
|
||||
pub fn new(
|
||||
client_params: ClientParameters,
|
||||
key_set_cache_path: Option<&Path>,
|
||||
) -> Result<ClientSupport, CompilerError> {
|
||||
unsafe {
|
||||
let key_set_cache = match key_set_cache_path {
|
||||
Some(path) => {
|
||||
let cache_path_buffer = path.to_str().unwrap().as_bytes();
|
||||
let cache = keySetCacheCreate(MlirStringRef {
|
||||
data: cache_path_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: cache_path_buffer.len() as size_t,
|
||||
});
|
||||
if keySetCacheIsNull(cache) {
|
||||
return Err(CompilerError(
|
||||
"failed creating keyset cache from path".to_string(),
|
||||
));
|
||||
}
|
||||
Some(cache)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
Ok(ClientSupport {
|
||||
client_params,
|
||||
key_set_cache,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch a keyset based on the client parameters, and the different seeds.
|
||||
///
|
||||
/// If a cache has already been set, this operation would first try to load an existing key,
|
||||
/// and generate a new one if no compatible keyset exists.
|
||||
pub fn keyset(
|
||||
&self,
|
||||
seed_msb: Option<u64>,
|
||||
seed_lsb: Option<u64>,
|
||||
) -> Result<KeySet, CompilerError> {
|
||||
unsafe {
|
||||
let key_set = match self.key_set_cache {
|
||||
Some(cache) => keySetCacheLoadOrGenerateKeySet(
|
||||
cache,
|
||||
self.client_params,
|
||||
seed_msb.unwrap_or(0),
|
||||
seed_lsb.unwrap_or(0),
|
||||
),
|
||||
None => keySetGenerate(
|
||||
self.client_params,
|
||||
seed_msb.unwrap_or(0),
|
||||
seed_lsb.unwrap_or(0),
|
||||
),
|
||||
};
|
||||
if keySetIsNull(key_set) {
|
||||
return Err(CompilerError("getting keyset failed".to_string()));
|
||||
}
|
||||
Ok(key_set)
|
||||
}
|
||||
}
|
||||
|
||||
/// Encrypt arguments of a compiled circuit.
|
||||
pub fn encrypt_args(
|
||||
&self,
|
||||
args: &[LambdaArgument],
|
||||
key_set: KeySet,
|
||||
) -> Result<PublicArguments, CompilerError> {
|
||||
unsafe {
|
||||
let public_args = lambdaArgumentEncrypt(
|
||||
args.as_ptr(),
|
||||
args.len() as u64,
|
||||
self.client_params,
|
||||
key_set,
|
||||
);
|
||||
if publicArgumentsIsNull(public_args) {
|
||||
return Err(CompilerError("encryption failed".to_string()));
|
||||
}
|
||||
Ok(public_args)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_result(
|
||||
&self,
|
||||
result: PublicResult,
|
||||
key_set: KeySet,
|
||||
) -> Result<LambdaArgument, CompilerError> {
|
||||
unsafe {
|
||||
let arg = publicResultDecrypt(result, key_set);
|
||||
if lambdaArgumentIsNull(arg) {
|
||||
return Err(CompilerError("decryption failed".to_string()));
|
||||
}
|
||||
Ok(arg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -152,7 +310,7 @@ mod test {
|
||||
fn test_compiler_round_trip_invalid_mlir() {
|
||||
let module_to_compile = "bla bla bla";
|
||||
let result_str = round_trip(module_to_compile);
|
||||
assert!(matches!(result_str, Err(CompilationError(_))));
|
||||
assert!(matches!(result_str, Err(CompilerError(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -171,7 +329,8 @@ mod test {
|
||||
let support = LibrarySupport::new(
|
||||
temp_dir.path().to_str().unwrap(),
|
||||
runtime_library_path.as_str(),
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
let lib = support.compile(module_to_compile, None).unwrap();
|
||||
assert!(!libraryCompilationResultIsNull(lib));
|
||||
libraryCompilationResultDestroy(lib);
|
||||
@@ -192,7 +351,7 @@ mod test {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compiler_load_server_lambda() {
|
||||
fn test_compiler_load_server_lambda_and_client_parameters() {
|
||||
unsafe {
|
||||
let module_to_compile = "
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
@@ -207,12 +366,62 @@ mod test {
|
||||
let support = LibrarySupport::new(
|
||||
temp_dir.path().to_str().unwrap(),
|
||||
runtime_library_path.as_str(),
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
let result = support.compile(module_to_compile, None).unwrap();
|
||||
let server = support.load_server_lambda(result).unwrap();
|
||||
assert!(!serverLambdaIsNull(server));
|
||||
libraryCompilationResultDestroy(result);
|
||||
serverLambdaDestroy(server);
|
||||
let client_params = support.load_client_parameters(result).unwrap();
|
||||
assert!(!clientParametersIsNull(client_params));
|
||||
libraryCompilationResultDestroy(result);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compiler_compile_and_exec_scalar_args() {
|
||||
unsafe {
|
||||
let module_to_compile = "
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %0 : !FHE.eint<5>
|
||||
}";
|
||||
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
|
||||
Ok(val) => val + "/lib/libConcretelangRuntime.so",
|
||||
Err(_e) => "".to_string(),
|
||||
};
|
||||
let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_scalar_args").unwrap();
|
||||
let lib_support = LibrarySupport::new(
|
||||
temp_dir.path().to_str().unwrap(),
|
||||
runtime_library_path.as_str(),
|
||||
)
|
||||
.unwrap();
|
||||
// compile
|
||||
let result = lib_support.compile(module_to_compile, None).unwrap();
|
||||
// loading materials from compilation
|
||||
// - server_lambda: used for execution
|
||||
// - client_parameters: used for keygen, encryption, and evaluation keys
|
||||
let server_lambda = lib_support.load_server_lambda(result).unwrap();
|
||||
let client_params = lib_support.load_client_parameters(result).unwrap();
|
||||
let client_support = ClientSupport::new(client_params, None).unwrap();
|
||||
let key_set = client_support.keyset(None, None).unwrap();
|
||||
let eval_keys = keySetGetEvaluationKeys(key_set);
|
||||
// build lambda arguments from scalar and encrypt them
|
||||
let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)];
|
||||
let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
|
||||
// free args
|
||||
args.map(|arg| lambdaArgumentDestroy(arg));
|
||||
// execute the compiled function on the encrypted arguments
|
||||
let encrypted_result = lib_support
|
||||
.server_lambda_call(server_lambda, encrypted_args, eval_keys)
|
||||
.unwrap();
|
||||
// decrypt the result of execution
|
||||
let result_arg = client_support
|
||||
.decrypt_result(encrypted_result, key_set)
|
||||
.unwrap();
|
||||
// get the scalar value from the result lambda argument
|
||||
let result = lambdaArgumentGetScalar(result_arg);
|
||||
assert_eq!(result, 6);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "concretelang/CAPI/Wrappers.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/LambdaSupport.h"
|
||||
#include "mlir/IR/Diagnostics.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
|
||||
@@ -186,7 +187,7 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support,
|
||||
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL);
|
||||
}
|
||||
return wrap(new mlir::concretelang::LibraryCompilationResult(
|
||||
*retOrError.get().get()));
|
||||
*retOrError.get().release()));
|
||||
}
|
||||
|
||||
ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
|
||||
@@ -200,6 +201,169 @@ ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
|
||||
serverLambdaOrError.get()));
|
||||
}
|
||||
|
||||
ClientParameters
|
||||
librarySupportLoadClientParameters(LibrarySupport support,
|
||||
LibraryCompilationResult result) {
|
||||
auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result));
|
||||
if (!paramsOrError) {
|
||||
llvm::errs() << llvm::toString(paramsOrError.takeError());
|
||||
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL);
|
||||
}
|
||||
return wrap(
|
||||
new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get()));
|
||||
}
|
||||
|
||||
PublicResult librarySupportServerCall(LibrarySupport support,
|
||||
ServerLambda server_lambda,
|
||||
PublicArguments args,
|
||||
EvaluationKeys evalKeys) {
|
||||
auto resultOrError = unwrap(support)->serverCall(
|
||||
*unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys));
|
||||
if (!resultOrError) {
|
||||
llvm::errs() << llvm::toString(resultOrError.takeError());
|
||||
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL);
|
||||
}
|
||||
return wrap(resultOrError.get().release());
|
||||
}
|
||||
|
||||
void librarySupportDestroy(LibrarySupport support) { delete unwrap(support); }
|
||||
|
||||
/// ********** ServerLamda CAPI ************************************************
|
||||
|
||||
void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); }
|
||||
|
||||
/// ********** ClientParameters CAPI *******************************************
|
||||
|
||||
void clientParametersDestroy(ClientParameters params) { delete unwrap(params); }
|
||||
|
||||
/// ********** KeySet CAPI *****************************************************
|
||||
|
||||
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
|
||||
*unwrap(params), seed_msb, seed_lsb);
|
||||
if (keySet.has_error()) {
|
||||
llvm::errs() << keySet.error().mesg;
|
||||
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
|
||||
}
|
||||
return wrap(keySet.value().release());
|
||||
}
|
||||
|
||||
EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) {
|
||||
return wrap(new mlir::concretelang::clientlib::EvaluationKeys(
|
||||
unwrap(keySet)->evaluationKeys()));
|
||||
}
|
||||
|
||||
void keySetDestroy(KeySet keySet) { delete unwrap(keySet); }
|
||||
|
||||
/// ********** KeySetCache CAPI ************************************************
|
||||
|
||||
KeySetCache keySetCacheCreate(MlirStringRef cachePath) {
|
||||
std::string cachePathStr(cachePath.data, cachePath.length);
|
||||
return wrap(new mlir::concretelang::clientlib::KeySetCache(cachePathStr));
|
||||
}
|
||||
|
||||
KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
|
||||
ClientParameters params,
|
||||
uint64_t seed_msb, uint64_t seed_lsb) {
|
||||
auto keySetOrError =
|
||||
unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb);
|
||||
if (keySetOrError.has_error()) {
|
||||
llvm::errs() << keySetOrError.error().mesg;
|
||||
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
|
||||
}
|
||||
return wrap(keySetOrError.value().release());
|
||||
}
|
||||
|
||||
void keySetCacheDestroy(KeySetCache keySetCache) { delete unwrap(keySetCache); }
|
||||
|
||||
/// ********** EvaluationKeys CAPI *********************************************
|
||||
|
||||
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
|
||||
delete unwrap(evaluationKeys);
|
||||
}
|
||||
|
||||
/// ********** LambdaArgument CAPI *********************************************
|
||||
|
||||
LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
|
||||
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
|
||||
}
|
||||
|
||||
// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
|
||||
// size_t rank);
|
||||
|
||||
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
}
|
||||
|
||||
uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg) {
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
|
||||
unwrap(lambdaArg)
|
||||
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
assert(arg != nullptr && "lambda argument isn't a scalar");
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
|
||||
return unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
|
||||
unwrap(lambdaArg)
|
||||
->isa<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
|
||||
}
|
||||
|
||||
// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
|
||||
// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
|
||||
// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
|
||||
|
||||
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
|
||||
size_t argNumber, ClientParameters params,
|
||||
KeySet keySet) {
|
||||
std::vector<const mlir::concretelang::LambdaArgument *> args;
|
||||
for (size_t i = 0; i < argNumber; i++)
|
||||
args.push_back(unwrap(lambdaArgs[i]));
|
||||
auto publicArgsOrError =
|
||||
mlir::concretelang::LambdaSupport<int, int>::exportArguments(
|
||||
*unwrap(params), *unwrap(keySet), args);
|
||||
if (!publicArgsOrError) {
|
||||
llvm::errs() << llvm::toString(publicArgsOrError.takeError());
|
||||
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL);
|
||||
}
|
||||
return wrap(publicArgsOrError.get().release());
|
||||
}
|
||||
|
||||
void lambdaArgumentDestroy(LambdaArgument lambdaArg) {
|
||||
delete unwrap(lambdaArg);
|
||||
}
|
||||
|
||||
/// ********** PublicArguments CAPI ********************************************
|
||||
|
||||
void publicArgumentsDestroy(PublicArguments publicArgs) {
|
||||
delete unwrap(publicArgs);
|
||||
}
|
||||
|
||||
/// ********** PublicResult CAPI ***********************************************
|
||||
|
||||
LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
|
||||
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
|
||||
lambdaArgOrError = mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
*unwrap(keySet), *unwrap(publicResult));
|
||||
if (!lambdaArgOrError) {
|
||||
llvm::errs() << llvm::toString(lambdaArgOrError.takeError());
|
||||
return wrap((mlir::concretelang::LambdaArgument *)NULL);
|
||||
}
|
||||
return wrap(lambdaArgOrError.get().release());
|
||||
}
|
||||
|
||||
void publicResultDestroy(PublicResult publicResult) {
|
||||
delete unwrap(publicResult);
|
||||
}
|
||||
|
||||
@@ -353,5 +353,11 @@ KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
|
||||
: KeySet::generate(params, seed_msb, seed_lsb);
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
KeySetCache::generate(ClientParameters ¶ms, uint64_t seed_msb,
|
||||
uint64_t seed_lsb) {
|
||||
return loadOrGenerateSave(params, seed_msb, seed_lsb);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
Reference in New Issue
Block a user