feat(rust): load server lambda for later execution

This commit is contained in:
youben11
2022-11-24 11:59:39 +01:00
committed by Ayoub Benaissa
parent b00115f4ae
commit 7f55385ea2
4 changed files with 74 additions and 1 deletions

View File

@@ -32,6 +32,7 @@ DEFINE_C_API_STRUCT(LibraryCompilationResult, void);
DEFINE_C_API_STRUCT(LibrarySupport, void);
DEFINE_C_API_STRUCT(CompilationOptions, void);
DEFINE_C_API_STRUCT(OptimizerConfig, void);
DEFINE_C_API_STRUCT(ServerLambda, void);
#undef DEFINE_C_API_STRUCT
@@ -49,6 +50,7 @@ DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull,
DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport);
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions);
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig);
DEFINE_NULL_PTR_CHECKER(serverLambdaIsNull, ServerLambda);
#undef DEFINE_NULL_PTR_CHECKER
@@ -125,6 +127,7 @@ MLIR_CAPI_EXPORTED void
libraryCompilationResultDestroy(LibraryCompilationResult result);
/// ********** LibrarySupport CAPI *********************************************
MLIR_CAPI_EXPORTED LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
@@ -140,6 +143,13 @@ MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault(
MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
LibrarySupport support, MlirStringRef module, CompilationOptions options);
MLIR_CAPI_EXPORTED ServerLambda librarySupportLoadServerLambda(
LibrarySupport support, LibraryCompilationResult result);
/// ********** ServerLamda CAPI ************************************************
MLIR_CAPI_EXPORTED void serverLambdaDestroy(ServerLambda server);
#ifdef __cplusplus
}
#endif

View File

@@ -23,5 +23,7 @@ DEFINE_C_API_PTR_METHODS(LibrarySupport, mlir::concretelang::LibrarySupport)
DEFINE_C_API_PTR_METHODS(CompilationOptions,
mlir::concretelang::CompilationOptions)
DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config)
DEFINE_C_API_PTR_METHODS(ServerLambda,
mlir::concretelang::serverlib::ServerLambda)
#endif

View File

@@ -5,6 +5,9 @@ use crate::mlir::ffi::*;
#[derive(Debug)]
pub struct CompilationError(String);
#[derive(Debug)]
pub struct ServerLambdaLoadError(String);
/// Parse the MLIR code and returns it.
///
/// The function parse the provided MLIR textual representation and returns it. It would fail with
@@ -100,6 +103,24 @@ impl LibrarySupport {
Ok(result)
}
}
/// Load server lambda from a compilation result.
///
/// This can be used for executing the compiled function.
pub fn load_server_lambda(
&self,
result: LibraryCompilationResult,
) -> Result<ServerLambda, ServerLambdaLoadError> {
unsafe {
let server = librarySupportLoadServerLambda(self.support, result);
if serverLambdaIsNull(server) {
return Err(ServerLambdaLoadError(
"loading server lambda failed".to_string(),
));
}
Ok(server)
}
}
}
#[cfg(test)]
@@ -169,4 +190,29 @@ mod test {
assert!(libraryIsNull(lib));
}
}
#[test]
fn test_compiler_load_server_lambda() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => val + "/lib/libConcretelangRuntime.so",
Err(_e) => "".to_string(),
};
let temp_dir = TempDir::new("rust_test_compiler_load_server_lambda").unwrap();
let support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
);
let result = support.compile(module_to_compile, None).unwrap();
let server = support.load_server_lambda(result).unwrap();
assert!(!serverLambdaIsNull(server));
libraryCompilationResultDestroy(result);
serverLambdaDestroy(server);
}
}
}

View File

@@ -187,4 +187,19 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support,
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().get()));
}
}
ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
LibraryCompilationResult result) {
auto serverLambdaOrError = unwrap(support)->loadServerLambda(*unwrap(result));
if (!serverLambdaOrError) {
llvm::errs() << llvm::toString(serverLambdaOrError.takeError());
return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL);
}
return wrap(new mlir::concretelang::serverlib::ServerLambda(
serverLambdaOrError.get()));
}
/// ********** ServerLamda CAPI ************************************************
void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); }