From b00115f4ae2817019738928de38ca08d68ccf944 Mon Sep 17 00:00:00 2001 From: youben11 Date: Thu, 24 Nov 2022 10:42:18 +0100 Subject: [PATCH] feat(rust): compile mlir into library CAPI covering a wider API of the Support library. Better error handling. Could also be improved by returning an error message back from C to rust (left TODO). --- .../concretelang-c/Support/CompilerEngine.h | 77 +++++++- compiler/include/concretelang/CAPI/Wrappers.h | 10 ++ compiler/lib/Bindings/Rust/Cargo.toml | 2 + compiler/lib/Bindings/Rust/src/compiler.rs | 111 +++++++++++- compiler/lib/CAPI/Support/CompilerEngine.cpp | 164 ++++++++++++++++-- 5 files changed, 344 insertions(+), 20 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index 9fbe67b2c..c754ea320 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -12,6 +12,10 @@ extern "C" { #endif +// TODO: add a char* to the struct that can return an error message in case of +// an error (where ptr would be null). Error messages can be returned using +// llvm::toString(error.takeError()) and allocating a buffer for the message and +// copy it. The buffer can later be freed during struct destruction. /// Opaque type declarations. Refer to llvm-project/mlir/include/mlir-c/IR.h for /// more info #define DEFINE_C_API_STRUCT(name, storage) \ @@ -23,6 +27,11 @@ extern "C" { DEFINE_C_API_STRUCT(CompilerEngine, void); DEFINE_C_API_STRUCT(CompilationContext, void); DEFINE_C_API_STRUCT(CompilationResult, void); +DEFINE_C_API_STRUCT(Library, void); +DEFINE_C_API_STRUCT(LibraryCompilationResult, void); +DEFINE_C_API_STRUCT(LibrarySupport, void); +DEFINE_C_API_STRUCT(CompilationOptions, void); +DEFINE_C_API_STRUCT(OptimizerConfig, void); #undef DEFINE_C_API_STRUCT @@ -32,7 +41,14 @@ DEFINE_C_API_STRUCT(CompilationResult, void); bool funcname(storage s) { return s.ptr == NULL; } DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine); +DEFINE_NULL_PTR_CHECKER(compilationContextIsNull, CompilationContext); DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult); +DEFINE_NULL_PTR_CHECKER(libraryIsNull, Library); +DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull, + LibraryCompilationResult); +DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport); +DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions); +DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig); #undef DEFINE_NULL_PTR_CHECKER @@ -42,9 +58,33 @@ DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult); /// ********** CompilationTarget CAPI ****************************************** -enum CompilationTarget { ROUND_TRIP, OTHER }; +enum CompilationTarget { + ROUND_TRIP, + FHE, + TFHE, + CONCRETE, + CONCRETEWITHLOOPS, + BCONCRETE, + STD, + LLVM, + LLVM_IR, + OPTIMIZED_LLVM_IR, + LIBRARY +}; typedef enum CompilationTarget CompilationTarget; +/// ********** CompilationOptions CAPI ***************************************** + +MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate(); + +MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault(); + +/// ********** OptimizerConfig CAPI ******************************************** + +MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreate(); + +MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault(); + /// ********** CompilerEngine CAPI ********************************************* MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate(); @@ -54,6 +94,10 @@ MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine); MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile( CompilerEngine engine, MlirStringRef module, CompilationTarget target); +MLIR_CAPI_EXPORTED void +compilerEngineCompileSetOptions(CompilerEngine engine, + CompilationOptions options); + /// ********** CompilationResult CAPI ****************************************** /// Get a string reference holding the textual representation of the compiled @@ -65,6 +109,37 @@ compilationResultGetModuleString(CompilationResult result); /// Free memory allocated for the module string. MLIR_CAPI_EXPORTED void compilationResultDestroyModuleString(MlirStringRef str); +MLIR_CAPI_EXPORTED void compilationResultDestroy(CompilationResult result); + +/// ********** Library CAPI **************************************************** + +MLIR_CAPI_EXPORTED Library libraryCreate(MlirStringRef outputDirPath, + MlirStringRef runtimeLibraryPath, + bool cleanUp); + +MLIR_CAPI_EXPORTED void libraryDestroy(Library lib); + +/// ********** LibraryCompilationResult CAPI *********************************** + +MLIR_CAPI_EXPORTED void +libraryCompilationResultDestroy(LibraryCompilationResult result); + +/// ********** LibrarySupport CAPI ********************************************* +MLIR_CAPI_EXPORTED LibrarySupport +librarySupportCreate(MlirStringRef outputDirPath, + MlirStringRef runtimeLibraryPath, bool generateSharedLib, + bool generateStaticLib, bool generateClientParameters, + bool generateCompilationFeedback, bool generateCppHeader); + +MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault( + MlirStringRef outputDirPath, MlirStringRef runtimeLibraryPath) { + return librarySupportCreate(outputDirPath, runtimeLibraryPath, true, true, + true, true, true); +} + +MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile( + LibrarySupport support, MlirStringRef module, CompilationOptions options); + #ifdef __cplusplus } #endif diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h index 541b6928f..116181fa3 100644 --- a/compiler/include/concretelang/CAPI/Wrappers.h +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -8,10 +8,20 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/LibrarySupport.h" #include "mlir/CAPI/Wrap.h" DEFINE_C_API_PTR_METHODS(CompilerEngine, mlir::concretelang::CompilerEngine) +DEFINE_C_API_PTR_METHODS(CompilationContext, + mlir::concretelang::CompilationContext) DEFINE_C_API_PTR_METHODS(CompilationResult, mlir::concretelang::CompilerEngine::CompilationResult) +DEFINE_C_API_PTR_METHODS(Library, mlir::concretelang::CompilerEngine::Library) +DEFINE_C_API_PTR_METHODS(LibraryCompilationResult, + mlir::concretelang::LibraryCompilationResult) +DEFINE_C_API_PTR_METHODS(LibrarySupport, mlir::concretelang::LibrarySupport) +DEFINE_C_API_PTR_METHODS(CompilationOptions, + mlir::concretelang::CompilationOptions) +DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config) #endif diff --git a/compiler/lib/Bindings/Rust/Cargo.toml b/compiler/lib/Bindings/Rust/Cargo.toml index fed09d820..4fca9cf14 100644 --- a/compiler/lib/Bindings/Rust/Cargo.toml +++ b/compiler/lib/Bindings/Rust/Cargo.toml @@ -6,3 +6,5 @@ edition = "2021" [build-dependencies] bindgen = "0.60.1" +[dev-dependencies] +tempdir = "0.3.7" diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 21d8cf715..79be5fe18 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -2,6 +2,9 @@ use crate::mlir::ffi::*; +#[derive(Debug)] +pub struct CompilationError(String); + /// Parse the MLIR code and returns it. /// /// The function parse the provided MLIR textual representation and returns it. It would fail with @@ -19,7 +22,7 @@ use crate::mlir::ffi::*; /// let result_str = round_trip(module_to_compile); /// ``` /// -pub fn round_trip(mlir_code: &str) -> String { +pub fn round_trip(mlir_code: &str) -> Result { unsafe { let engine = compilerEngineCreate(); let mlir_code_buffer = mlir_code.as_bytes(); @@ -31,6 +34,9 @@ pub fn round_trip(mlir_code: &str) -> String { }, CompilationTarget_ROUND_TRIP, ); + if compilationResultIsNull(compilation_result) { + return Err(CompilationError("roundtrip error".to_string())); + } let module_compiled = compilationResultGetModuleString(compilation_result); let result_str = String::from_utf8_lossy(std::slice::from_raw_parts( module_compiled.data as *const u8, @@ -39,12 +45,68 @@ pub fn round_trip(mlir_code: &str) -> String { .to_string(); compilationResultDestroyModuleString(module_compiled); compilerEngineDestroy(engine); - result_str + Ok(result_str) + } +} + +/// Support for compiling and executing libraries. +pub struct LibrarySupport { + support: crate::mlir::ffi::LibrarySupport, +} + +impl LibrarySupport { + /// LibrarySupport manages build files generated by the compiler under the `output_dir_path`. + /// + /// The compiled library needs to link to the runtime for proper execution. + pub fn new(output_dir_path: &str, runtime_library_path: &str) -> LibrarySupport { + unsafe { + let output_dir_path_buffer = output_dir_path.as_bytes(); + let runtime_library_path_buffer = runtime_library_path.as_bytes(); + LibrarySupport { + support: librarySupportCreateDefault( + MlirStringRef { + data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char, + length: output_dir_path_buffer.len() as size_t, + }, + MlirStringRef { + data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char, + length: runtime_library_path_buffer.len() as size_t, + }, + ), + } + } + } + + /// Compile an MLIR into a library. + pub fn compile( + &self, + mlir_code: &str, + options: Option, + ) -> Result { + unsafe { + let options = options.unwrap_or_else(|| compilationOptionsCreateDefault()); + let mlir_code_buffer = mlir_code.as_bytes(); + let result = librarySupportCompile( + self.support, + MlirStringRef { + data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char, + length: mlir_code_buffer.len() as size_t, + }, + options, + ); + if libraryCompilationResultIsNull(result) { + return Err(CompilationError("library compilation failed".to_string())); + } + Ok(result) + } } } #[cfg(test)] mod test { + use std::env; + use tempdir::TempDir; + use super::*; #[test] @@ -54,7 +116,7 @@ mod test { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> return %0 : !FHE.eint<5> }"; - let result_str = round_trip(module_to_compile); + let result_str = round_trip(module_to_compile).unwrap(); let expected_module = "module { func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> @@ -64,4 +126,47 @@ mod test { "; assert_eq!(expected_module, result_str); } + + #[test] + fn test_compiler_round_trip_invalid_mlir() { + let module_to_compile = "bla bla bla"; + let result_str = round_trip(module_to_compile); + assert!(matches!(result_str, Err(CompilationError(_)))); + } + + #[test] + fn test_compiler_compile_lib() { + unsafe { + let module_to_compile = " + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + }"; + let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") { + Ok(val) => val + "/lib/libConcretelangRuntime.so", + Err(_e) => "".to_string(), + }; + let temp_dir = TempDir::new("rust_test_compiler_compile_lib").unwrap(); + let support = LibrarySupport::new( + temp_dir.path().to_str().unwrap(), + runtime_library_path.as_str(), + ); + let lib = support.compile(module_to_compile, None).unwrap(); + assert!(!libraryCompilationResultIsNull(lib)); + libraryCompilationResultDestroy(lib); + // the sharedlib should be enough as a sign that the compilation worked + assert!(temp_dir.path().join("sharedlib.so").exists()); + } + } + + /// We want to make sure setting a pointer to null in rust passes the nullptr check in C/Cpp + #[test] + fn test_compiler_null_ptr_compatibility() { + unsafe { + let lib = Library { + ptr: std::ptr::null_mut(), + }; + assert!(libraryIsNull(lib)); + } + } } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index a3f360372..57e702627 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -6,9 +6,58 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/CAPI/Wrappers.h" #include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/Error.h" #include "mlir/IR/Diagnostics.h" +#include "llvm/Support/SourceMgr.h" -/// CompilerEngine CAPI +/// ********** CompilationOptions CAPI ***************************************** + +CompilationOptions +compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize, + bool batchConcreteOps, bool dataflowParallelize, + bool emitGPUOps, bool loopParallelize, + bool optimizeConcrete, OptimizerConfig optimizerConfig, + bool verifyDiagnostics) { + std::string funcNameStr(funcName.data, funcName.length); + auto options = new mlir::concretelang::CompilationOptions(funcNameStr); + options->autoParallelize = autoParallelize; + options->batchConcreteOps = batchConcreteOps; + options->dataflowParallelize = dataflowParallelize; + options->emitGPUOps = emitGPUOps; + options->loopParallelize = loopParallelize; + options->optimizeConcrete = optimizeConcrete; + options->optimizerConfig = *unwrap(optimizerConfig); + options->verifyDiagnostics = verifyDiagnostics; + return wrap(options); +} + +CompilationOptions compilationOptionsCreateDefault() { + return wrap(new mlir::concretelang::CompilationOptions("main")); +} + +/// ********** OptimizerConfig CAPI ******************************************** + +OptimizerConfig optimizerConfigCreate(bool display, + double fallback_log_norm_woppbs, + double global_p_error, double p_error, + uint64_t security, bool strategy_v0, + bool use_gpu_constraints) { + auto config = new mlir::concretelang::optimizer::Config(); + config->display = display; + config->fallback_log_norm_woppbs = fallback_log_norm_woppbs; + config->global_p_error = global_p_error; + config->p_error = p_error; + config->security = security; + config->strategy_v0 = strategy_v0; + config->use_gpu_constraints = use_gpu_constraints; + return wrap(config); +} + +OptimizerConfig optimizerConfigCreateDefault() { + return wrap(new mlir::concretelang::optimizer::Config()); +} + +/// ********** CompilerEngine CAPI ********************************************* CompilerEngine compilerEngineCreate() { auto *engine = new mlir::concretelang::CompilerEngine( @@ -18,30 +67,61 @@ CompilerEngine compilerEngineCreate() { void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); } +/// Map C compilationTarget to Cpp +llvm::Expected +targetConvertToCppFromC(CompilationTarget target) { + switch (target) { + case ROUND_TRIP: + return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP; + case FHE: + return mlir::concretelang::CompilerEngine::Target::FHE; + case TFHE: + return mlir::concretelang::CompilerEngine::Target::TFHE; + case CONCRETE: + return mlir::concretelang::CompilerEngine::Target::CONCRETE; + case CONCRETEWITHLOOPS: + return mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS; + case BCONCRETE: + return mlir::concretelang::CompilerEngine::Target::BCONCRETE; + case STD: + return mlir::concretelang::CompilerEngine::Target::STD; + case LLVM: + return mlir::concretelang::CompilerEngine::Target::LLVM; + case LLVM_IR: + return mlir::concretelang::CompilerEngine::Target::LLVM_IR; + case OPTIMIZED_LLVM_IR: + return mlir::concretelang::CompilerEngine::Target::OPTIMIZED_LLVM_IR; + case LIBRARY: + return mlir::concretelang::CompilerEngine::Target::LIBRARY; + } + return mlir::concretelang::StreamStringError("invalid compilation target"); +} + CompilationResult compilerEngineCompile(CompilerEngine engine, MlirStringRef module, CompilationTarget target) { std::string module_str(module.data, module.length); - if (target == ROUND_TRIP) { - auto retOrError = unwrap(engine)->compile( - module_str, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP); - if (!retOrError) { - // TODO: access the MlirContext - // mlir::emitError(mlir::UnknownLoc::get(unwrap(engine)) << "azeza"; - return wrap( - (mlir::concretelang::CompilerEngine::CompilationResult *)nullptr); - } - return wrap(new mlir::concretelang::CompilerEngine::CompilationResult( - std::move(retOrError.get()))); + auto targetCppOrError = targetConvertToCppFromC(target); + if (!targetCppOrError) { // invalid target + llvm::errs() << llvm::toString(targetCppOrError.takeError()); + return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL); } - return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)nullptr); + auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get()); + if (!retOrError) { // compilation error + llvm::errs() << llvm::toString(retOrError.takeError()); + return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL); + } + return wrap(new mlir::concretelang::CompilerEngine::CompilationResult( + std::move(retOrError.get()))); } -/// CompilationResult CAPI -void compilationResultDestroy(CompilationResult result) { - delete unwrap(result); +void compilerEngineCompileSetOptions(CompilerEngine engine, + CompilationOptions options) { + unwrap(engine)->setCompilationOptions(*unwrap(options)); } +/// ********** CompilationResult CAPI ****************************************** + MlirStringRef compilationResultGetModuleString(CompilationResult result) { // print the module into a string std::string moduleString; @@ -56,3 +136,55 @@ MlirStringRef compilationResultGetModuleString(CompilationResult result) { void compilationResultDestroyModuleString(MlirStringRef str) { delete str.data; } + +void compilationResultDestroy(CompilationResult result) { + delete unwrap(result); +} + +/// ********** Library CAPI **************************************************** + +Library libraryCreate(MlirStringRef outputDirPath, + MlirStringRef runtimeLibraryPath, bool cleanUp) { + std::string outputDirPathStr(outputDirPath.data, outputDirPath.length); + std::string runtimeLibraryPathStr(runtimeLibraryPath.data, + runtimeLibraryPath.length); + return wrap(new mlir::concretelang::CompilerEngine::Library( + outputDirPathStr, runtimeLibraryPathStr, cleanUp)); +} + +void libraryDestroy(Library lib) { delete unwrap(lib); } + +/// ********** LibraryCompilationResult CAPI *********************************** + +void libraryCompilationResultDestroy(LibraryCompilationResult result) { + delete unwrap(result); +} + +/// ********** LibrarySupport CAPI ********************************************* + +LibrarySupport +librarySupportCreate(MlirStringRef outputDirPath, + MlirStringRef runtimeLibraryPath, bool generateSharedLib, + bool generateStaticLib, bool generateClientParameters, + bool generateCompilationFeedback, bool generateCppHeader) { + std::string outputDirPathStr(outputDirPath.data, outputDirPath.length); + std::string runtimeLibraryPathStr(runtimeLibraryPath.data, + runtimeLibraryPath.length); + return wrap(new mlir::concretelang::LibrarySupport( + outputDirPathStr, runtimeLibraryPathStr, generateSharedLib, + generateStaticLib, generateClientParameters, generateCompilationFeedback, + generateCppHeader)); +} + +LibraryCompilationResult librarySupportCompile(LibrarySupport support, + MlirStringRef module, + CompilationOptions options) { + std::string moduleStr(module.data, module.length); + auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options)); + if (!retOrError) { + llvm::errs() << llvm::toString(retOrError.takeError()); + return wrap((mlir::concretelang::LibraryCompilationResult *)NULL); + } + return wrap(new mlir::concretelang::LibraryCompilationResult( + *retOrError.get().get())); +} \ No newline at end of file