feat(rust): support keygen, encryption, execution

This commit is contained in:
youben11
2022-11-25 14:16:47 +01:00
committed by Ayoub Benaissa
parent 7f55385ea2
commit 15b4aac0a1
6 changed files with 509 additions and 30 deletions

View File

@@ -33,6 +33,13 @@ DEFINE_C_API_STRUCT(LibrarySupport, void);
DEFINE_C_API_STRUCT(CompilationOptions, void);
DEFINE_C_API_STRUCT(OptimizerConfig, void);
DEFINE_C_API_STRUCT(ServerLambda, void);
DEFINE_C_API_STRUCT(ClientParameters, void);
DEFINE_C_API_STRUCT(KeySet, void);
DEFINE_C_API_STRUCT(KeySetCache, void);
DEFINE_C_API_STRUCT(EvaluationKeys, void);
DEFINE_C_API_STRUCT(LambdaArgument, void);
DEFINE_C_API_STRUCT(PublicArguments, void);
DEFINE_C_API_STRUCT(PublicResult, void);
#undef DEFINE_C_API_STRUCT
@@ -51,6 +58,13 @@ DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport);
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions);
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig);
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda);
DEFINE_NULL_PTR_CHECKER(clientParametersIsNull, ClientParameters);
DEFINE_NULL_PTR_CHECKER(keySetIsNull, KeySet);
DEFINE_NULL_PTR_CHECKER(keySetCacheIsNull, KeySetCache);
DEFINE_NULL_PTR_CHECKER(evaluationKeysIsNull, EvaluationKeys);
DEFINE_NULL_PTR_CHECKER(lambdaArgumentIsNull, LambdaArgument);
DEFINE_NULL_PTR_CHECKER(publicArgumentsIsNull, PublicArguments);
DEFINE_NULL_PTR_CHECKER(publicResultIsNull, PublicResult);
#undef DEFINE_NULL_PTR_CHECKER
@@ -146,10 +160,81 @@ MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
LibrarySupport support, LibraryCompilationResult result);
MLIR_CAPI_EXPORTED ClientParameters librarySupportLoadClientParameters(
LibrarySupport support, LibraryCompilationResult result);
MLIR_CAPI_EXPORTED PublicResult
librarySupportServerCall(LibrarySupport support, ServerLambda server,
PublicArguments args, EvaluationKeys evalKeys);
MLIR_CAPI_EXPORTED void librarySupportDestroy(LibrarySupport support);
/// ********** ServerLamda CAPI ************************************************
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
/// ********** ClientParameters CAPI *******************************************
MLIR_CAPI_EXPORTED void clientParametersDestroy(ClientParameters params);
/// ********** KeySet CAPI *****************************************************
MLIR_CAPI_EXPORTED KeySet keySetGenerate(ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb);
MLIR_CAPI_EXPORTED EvaluationKeys keySetGetEvaluationKeys(KeySet keySet);
MLIR_CAPI_EXPORTED void keySetDestroy(KeySet keySet);
/// ********** KeySetCache CAPI ************************************************
MLIR_CAPI_EXPORTED KeySetCache keySetCacheCreate(MlirStringRef cachePath);
MLIR_CAPI_EXPORTED KeySet
keySetCacheLoadOrGenerateKeySet(KeySetCache cache, ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb);
MLIR_CAPI_EXPORTED void keySetCacheDestroy(KeySetCache keySetCache);
/// ********** EvaluationKeys CAPI *********************************************
MLIR_CAPI_EXPORTED void evaluationKeysDestroy(EvaluationKeys evaluationKeys);
/// ********** LambdaArgument CAPI *********************************************
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromScalar(uint64_t value);
MLIR_CAPI_EXPORTED LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data,
int64_t *dims,
size_t rank);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED uint64_t *
lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED int64_t *
lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
MLIR_CAPI_EXPORTED PublicArguments
lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, size_t argNumber,
ClientParameters params, KeySet keySet);
MLIR_CAPI_EXPORTED void lambdaArgumentDestroy(LambdaArgument lambdaArg);
/// ********** PublicArguments CAPI ********************************************
MLIR_CAPI_EXPORTED void publicArgumentsDestroy(PublicArguments publicArgs);
/// ********** PublicResult CAPI ***********************************************
MLIR_CAPI_EXPORTED LambdaArgument publicResultDecrypt(PublicResult publicResult,
KeySet keySet);
MLIR_CAPI_EXPORTED void publicResultDestroy(PublicResult publicResult);
#ifdef __cplusplus
}
#endif

View File

@@ -25,5 +25,17 @@ DEFINE_C_API_PTR_METHODS(CompilationOptions,
DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config)
DEFINE_C_API_PTR_METHODS(ServerLambda,
mlir::concretelang::serverlib::ServerLambda)
DEFINE_C_API_PTR_METHODS(ClientParameters,
mlir::concretelang::clientlib::ClientParameters)
DEFINE_C_API_PTR_METHODS(KeySet, mlir::concretelang::clientlib::KeySet)
DEFINE_C_API_PTR_METHODS(KeySetCache,
mlir::concretelang::clientlib::KeySetCache)
DEFINE_C_API_PTR_METHODS(EvaluationKeys,
mlir::concretelang::clientlib::EvaluationKeys)
DEFINE_C_API_PTR_METHODS(LambdaArgument, mlir::concretelang::LambdaArgument)
DEFINE_C_API_PTR_METHODS(PublicArguments,
mlir::concretelang::clientlib::PublicArguments)
DEFINE_C_API_PTR_METHODS(PublicResult,
mlir::concretelang::clientlib::PublicResult)
#endif

View File

@@ -24,6 +24,9 @@ public:
generate(std::shared_ptr<KeySetCache> optionalCache, ClientParameters &params,
uint64_t seed_msb, uint64_t seed_lsb);
outcome::checked<std::unique_ptr<KeySet>, StringError>
generate(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb);
private:
static outcome::checked<std::unique_ptr<KeySet>, StringError>
loadKeys(ClientParameters &params, uint64_t seed_msb, uint64_t seed_lsb,

View File

@@ -1,12 +1,11 @@
//! Compiler module
use std::path::Path;
use crate::mlir::ffi::*;
#[derive(Debug)]
pub struct CompilationError(String);
#[derive(Debug)]
pub struct ServerLambdaLoadError(String);
pub struct CompilerError(String);
/// Parse the MLIR code and returns it.
///
@@ -25,7 +24,7 @@ pub struct ServerLambdaLoadError(String);
/// let result_str = round_trip(module_to_compile);
/// ```
///
pub fn round_trip(mlir_code: &str) -> Result<String, CompilationError> {
pub fn round_trip(mlir_code: &str) -> Result<String, CompilerError> {
unsafe {
let engine = compilerEngineCreate();
let mlir_code_buffer = mlir_code.as_bytes();
@@ -38,7 +37,7 @@ pub fn round_trip(mlir_code: &str) -> Result<String, CompilationError> {
CompilationTarget_ROUND_TRIP,
);
if compilationResultIsNull(compilation_result) {
return Err(CompilationError("roundtrip error".to_string()));
return Err(CompilerError("roundtrip error".to_string()));
}
let module_compiled = compilationResultGetModuleString(compilation_result);
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
@@ -57,26 +56,39 @@ pub struct LibrarySupport {
support: crate::mlir::ffi::LibrarySupport,
}
impl Drop for LibrarySupport {
fn drop(&mut self) {
unsafe {
librarySupportDestroy(self.support);
}
}
}
impl LibrarySupport {
/// LibrarySupport manages build files generated by the compiler under the `output_dir_path`.
///
/// The compiled library needs to link to the runtime for proper execution.
pub fn new(output_dir_path: &str, runtime_library_path: &str) -> LibrarySupport {
pub fn new(
output_dir_path: &str,
runtime_library_path: &str,
) -> Result<LibrarySupport, CompilerError> {
unsafe {
let output_dir_path_buffer = output_dir_path.as_bytes();
let runtime_library_path_buffer = runtime_library_path.as_bytes();
LibrarySupport {
support: librarySupportCreateDefault(
MlirStringRef {
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: output_dir_path_buffer.len() as size_t,
},
MlirStringRef {
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: runtime_library_path_buffer.len() as size_t,
},
),
let support = librarySupportCreateDefault(
MlirStringRef {
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: output_dir_path_buffer.len() as size_t,
},
MlirStringRef {
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: runtime_library_path_buffer.len() as size_t,
},
);
if librarySupportIsNull(support) {
return Err(CompilerError("failed creating library support".to_string()));
}
Ok(LibrarySupport { support })
}
}
@@ -85,7 +97,7 @@ impl LibrarySupport {
&self,
mlir_code: &str,
options: Option<CompilationOptions>,
) -> Result<LibraryCompilationResult, CompilationError> {
) -> Result<LibraryCompilationResult, CompilerError> {
unsafe {
let options = options.unwrap_or_else(|| compilationOptionsCreateDefault());
let mlir_code_buffer = mlir_code.as_bytes();
@@ -98,7 +110,7 @@ impl LibrarySupport {
options,
);
if libraryCompilationResultIsNull(result) {
return Err(CompilationError("library compilation failed".to_string()));
return Err(CompilerError("library compilation failed".to_string()));
}
Ok(result)
}
@@ -110,17 +122,163 @@ impl LibrarySupport {
pub fn load_server_lambda(
&self,
result: LibraryCompilationResult,
) -> Result<ServerLambda, ServerLambdaLoadError> {
) -> Result<ServerLambda, CompilerError> {
unsafe {
let server = librarySupportLoadServerLambda(self.support, result);
if serverLambdaIsNull(server) {
return Err(ServerLambdaLoadError(
"loading server lambda failed".to_string(),
));
return Err(CompilerError("loading server lambda failed".to_string()));
}
Ok(server)
}
}
/// Load client parameters from a compilation result.
///
/// This can be used for creating keys for the compiled library.
pub fn load_client_parameters(
&self,
result: LibraryCompilationResult,
) -> Result<ClientParameters, CompilerError> {
unsafe {
let params = librarySupportLoadClientParameters(self.support, result);
if clientParametersIsNull(params) {
return Err(CompilerError(
"loading client parameters failed".to_string(),
));
}
Ok(params)
}
}
/// Run a compiled circuit.
pub fn server_lambda_call(
&self,
server_lambda: ServerLambda,
args: PublicArguments,
eval_keys: EvaluationKeys,
) -> Result<PublicResult, CompilerError> {
unsafe {
let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys);
if publicResultIsNull(result) {
return Err(CompilerError("failed calling server lambda".to_string()));
}
Ok(result)
}
}
}
/// Support for keygen, encryption, and decryption.
///
/// Manages cache for keys if provided during creation.
pub struct ClientSupport {
client_params: crate::mlir::ffi::ClientParameters,
key_set_cache: Option<KeySetCache>,
}
impl Drop for ClientSupport {
fn drop(&mut self) {
unsafe {
clientParametersDestroy(self.client_params);
match self.key_set_cache {
Some(cache) => keySetCacheDestroy(cache),
None => (),
}
}
}
}
impl ClientSupport {
pub fn new(
client_params: ClientParameters,
key_set_cache_path: Option<&Path>,
) -> Result<ClientSupport, CompilerError> {
unsafe {
let key_set_cache = match key_set_cache_path {
Some(path) => {
let cache_path_buffer = path.to_str().unwrap().as_bytes();
let cache = keySetCacheCreate(MlirStringRef {
data: cache_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: cache_path_buffer.len() as size_t,
});
if keySetCacheIsNull(cache) {
return Err(CompilerError(
"failed creating keyset cache from path".to_string(),
));
}
Some(cache)
}
None => None,
};
Ok(ClientSupport {
client_params,
key_set_cache,
})
}
}
/// Fetch a keyset based on the client parameters, and the different seeds.
///
/// If a cache has already been set, this operation would first try to load an existing key,
/// and generate a new one if no compatible keyset exists.
pub fn keyset(
&self,
seed_msb: Option<u64>,
seed_lsb: Option<u64>,
) -> Result<KeySet, CompilerError> {
unsafe {
let key_set = match self.key_set_cache {
Some(cache) => keySetCacheLoadOrGenerateKeySet(
cache,
self.client_params,
seed_msb.unwrap_or(0),
seed_lsb.unwrap_or(0),
),
None => keySetGenerate(
self.client_params,
seed_msb.unwrap_or(0),
seed_lsb.unwrap_or(0),
),
};
if keySetIsNull(key_set) {
return Err(CompilerError("getting keyset failed".to_string()));
}
Ok(key_set)
}
}
/// Encrypt arguments of a compiled circuit.
pub fn encrypt_args(
&self,
args: &[LambdaArgument],
key_set: KeySet,
) -> Result<PublicArguments, CompilerError> {
unsafe {
let public_args = lambdaArgumentEncrypt(
args.as_ptr(),
args.len() as u64,
self.client_params,
key_set,
);
if publicArgumentsIsNull(public_args) {
return Err(CompilerError("encryption failed".to_string()));
}
Ok(public_args)
}
}
pub fn decrypt_result(
&self,
result: PublicResult,
key_set: KeySet,
) -> Result<LambdaArgument, CompilerError> {
unsafe {
let arg = publicResultDecrypt(result, key_set);
if lambdaArgumentIsNull(arg) {
return Err(CompilerError("decryption failed".to_string()));
}
Ok(arg)
}
}
}
#[cfg(test)]
@@ -152,7 +310,7 @@ mod test {
fn test_compiler_round_trip_invalid_mlir() {
let module_to_compile = "bla bla bla";
let result_str = round_trip(module_to_compile);
assert!(matches!(result_str, Err(CompilationError(_))));
assert!(matches!(result_str, Err(CompilerError(_))));
}
#[test]
@@ -171,7 +329,8 @@ mod test {
let support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
);
)
.unwrap();
let lib = support.compile(module_to_compile, None).unwrap();
assert!(!libraryCompilationResultIsNull(lib));
libraryCompilationResultDestroy(lib);
@@ -192,7 +351,7 @@ mod test {
}
#[test]
fn test_compiler_load_server_lambda() {
fn test_compiler_load_server_lambda_and_client_parameters() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
@@ -207,12 +366,62 @@ mod test {
let support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
);
)
.unwrap();
let result = support.compile(module_to_compile, None).unwrap();
let server = support.load_server_lambda(result).unwrap();
assert!(!serverLambdaIsNull(server));
libraryCompilationResultDestroy(result);
serverLambdaDestroy(server);
let client_params = support.load_client_parameters(result).unwrap();
assert!(!clientParametersIsNull(client_params));
libraryCompilationResultDestroy(result);
}
}
#[test]
fn test_compiler_compile_and_exec_scalar_args() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => val + "/lib/libConcretelangRuntime.so",
Err(_e) => "".to_string(),
};
let temp_dir = TempDir::new("rust_test_compiler_compile_and_exec_scalar_args").unwrap();
let lib_support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
)
.unwrap();
// compile
let result = lib_support.compile(module_to_compile, None).unwrap();
// loading materials from compilation
// - server_lambda: used for execution
// - client_parameters: used for keygen, encryption, and evaluation keys
let server_lambda = lib_support.load_server_lambda(result).unwrap();
let client_params = lib_support.load_client_parameters(result).unwrap();
let client_support = ClientSupport::new(client_params, None).unwrap();
let key_set = client_support.keyset(None, None).unwrap();
let eval_keys = keySetGetEvaluationKeys(key_set);
// build lambda arguments from scalar and encrypt them
let args = [lambdaArgumentFromScalar(4), lambdaArgumentFromScalar(2)];
let encrypted_args = client_support.encrypt_args(&args, key_set).unwrap();
// free args
args.map(|arg| lambdaArgumentDestroy(arg));
// execute the compiled function on the encrypted arguments
let encrypted_result = lib_support
.server_lambda_call(server_lambda, encrypted_args, eval_keys)
.unwrap();
// decrypt the result of execution
let result_arg = client_support
.decrypt_result(encrypted_result, key_set)
.unwrap();
// get the scalar value from the result lambda argument
let result = lambdaArgumentGetScalar(result_arg);
assert_eq!(result, 6);
}
}
}

View File

@@ -7,6 +7,7 @@
#include "concretelang/CAPI/Wrappers.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/LambdaSupport.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
@@ -186,7 +187,7 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support,
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL);
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().get()));
*retOrError.get().release()));
}
ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
@@ -200,6 +201,169 @@ ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
serverLambdaOrError.get()));
}
ClientParameters
librarySupportLoadClientParameters(LibrarySupport support,
LibraryCompilationResult result) {
auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result));
if (!paramsOrError) {
llvm::errs() << llvm::toString(paramsOrError.takeError());
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL);
}
return wrap(
new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get()));
}
PublicResult librarySupportServerCall(LibrarySupport support,
ServerLambda server_lambda,
PublicArguments args,
EvaluationKeys evalKeys) {
auto resultOrError = unwrap(support)->serverCall(
*unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys));
if (!resultOrError) {
llvm::errs() << llvm::toString(resultOrError.takeError());
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL);
}
return wrap(resultOrError.get().release());
}
void librarySupportDestroy(LibrarySupport support) { delete unwrap(support); }
/// ********** ServerLamda CAPI ************************************************
void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); }
/// ********** ClientParameters CAPI *******************************************
void clientParametersDestroy(ClientParameters params) { delete unwrap(params); }
/// ********** KeySet CAPI *****************************************************
KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
uint64_t seed_lsb) {
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
*unwrap(params), seed_msb, seed_lsb);
if (keySet.has_error()) {
llvm::errs() << keySet.error().mesg;
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
}
return wrap(keySet.value().release());
}
EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) {
return wrap(new mlir::concretelang::clientlib::EvaluationKeys(
unwrap(keySet)->evaluationKeys()));
}
void keySetDestroy(KeySet keySet) { delete unwrap(keySet); }
/// ********** KeySetCache CAPI ************************************************
KeySetCache keySetCacheCreate(MlirStringRef cachePath) {
std::string cachePathStr(cachePath.data, cachePath.length);
return wrap(new mlir::concretelang::clientlib::KeySetCache(cachePathStr));
}
KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
ClientParameters params,
uint64_t seed_msb, uint64_t seed_lsb) {
auto keySetOrError =
unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb);
if (keySetOrError.has_error()) {
llvm::errs() << keySetOrError.error().mesg;
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
}
return wrap(keySetOrError.value().release());
}
void keySetCacheDestroy(KeySetCache keySetCache) { delete unwrap(keySetCache); }
/// ********** EvaluationKeys CAPI *********************************************
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
delete unwrap(evaluationKeys);
}
/// ********** LambdaArgument CAPI *********************************************
LambdaArgument lambdaArgumentFromScalar(uint64_t value) {
return wrap(new mlir::concretelang::IntLambdaArgument<uint64_t>(value));
}
// LambdaArgument lambdaArgumentFromTensorU64(uint64_t *data, int64_t *dims,
// size_t rank);
bool lambdaArgumentIsScalar(LambdaArgument lambdaArg) {
return unwrap(lambdaArg)
->isa<mlir::concretelang::IntLambdaArgument<uint64_t>>();
}
uint64_t lambdaArgumentGetScalar(LambdaArgument lambdaArg) {
mlir::concretelang::IntLambdaArgument<uint64_t> *arg =
unwrap(lambdaArg)
->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
assert(arg != nullptr && "lambda argument isn't a scalar");
return arg->getValue();
}
bool lambdaArgumentIsTensor(LambdaArgument lambdaArg) {
return unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
unwrap(lambdaArg)
->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
}
// uint64_t *lambdaArgumentGetTensorData(LambdaArgument lambdaArg);
// size_t lambdaArgumentGetTensorRank(LambdaArgument lambdaArg);
// int64_t *lambdaArgumentGetTensorDims(LambdaArgument lambdaArg);
PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
size_t argNumber, ClientParameters params,
KeySet keySet) {
std::vector<const mlir::concretelang::LambdaArgument *> args;
for (size_t i = 0; i < argNumber; i++)
args.push_back(unwrap(lambdaArgs[i]));
auto publicArgsOrError =
mlir::concretelang::LambdaSupport<int, int>::exportArguments(
*unwrap(params), *unwrap(keySet), args);
if (!publicArgsOrError) {
llvm::errs() << llvm::toString(publicArgsOrError.takeError());
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL);
}
return wrap(publicArgsOrError.get().release());
}
void lambdaArgumentDestroy(LambdaArgument lambdaArg) {
delete unwrap(lambdaArg);
}
/// ********** PublicArguments CAPI ********************************************
void publicArgumentsDestroy(PublicArguments publicArgs) {
delete unwrap(publicArgs);
}
/// ********** PublicResult CAPI ***********************************************
LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
llvm::Expected<std::unique_ptr<mlir::concretelang::LambdaArgument>>
lambdaArgOrError = mlir::concretelang::typedResult<
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
*unwrap(keySet), *unwrap(publicResult));
if (!lambdaArgOrError) {
llvm::errs() << llvm::toString(lambdaArgOrError.takeError());
return wrap((mlir::concretelang::LambdaArgument *)NULL);
}
return wrap(lambdaArgOrError.get().release());
}
void publicResultDestroy(PublicResult publicResult) {
delete unwrap(publicResult);
}

View File

@@ -353,5 +353,11 @@ KeySetCache::generate(std::shared_ptr<KeySetCache> cache,
: KeySet::generate(params, seed_msb, seed_lsb);
}
outcome::checked<std::unique_ptr<KeySet>, StringError>
KeySetCache::generate(ClientParameters &params, uint64_t seed_msb,
uint64_t seed_lsb) {
return loadOrGenerateSave(params, seed_msb, seed_lsb);
}
} // namespace clientlib
} // namespace concretelang