From 16f3b0bbf6e2c64c06cf23b6a5e349caabd85e7b Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 28 Nov 2022 11:11:26 +0100 Subject: [PATCH] feat(rust): manage erros coming from compiler C struct now contains an additonal char* pointer, which can be either NULL in case there is no error, or a buffer containing the error message. It's the responsability of destructor function to free that memory. --- .../concretelang-c/Support/CompilerEngine.h | 12 +- compiler/include/concretelang/CAPI/Wrappers.h | 75 +++++++---- compiler/lib/Bindings/Rust/src/compiler.rs | 117 +++++++++++++++--- compiler/lib/CAPI/Support/CompilerEngine.cpp | 99 ++++++++------- 4 files changed, 209 insertions(+), 94 deletions(-) diff --git a/compiler/include/concretelang-c/Support/CompilerEngine.h b/compiler/include/concretelang-c/Support/CompilerEngine.h index bb91b7ec5..acab455dc 100644 --- a/compiler/include/concretelang-c/Support/CompilerEngine.h +++ b/compiler/include/concretelang-c/Support/CompilerEngine.h @@ -12,15 +12,15 @@ extern "C" { #endif -// TODO: add a char* to the struct that can return an error message in case of -// an error (where ptr would be null). Error messages can be returned using -// llvm::toString(error.takeError()) and allocating a buffer for the message and -// copy it. The buffer can later be freed during struct destruction. -/// Opaque type declarations. Refer to llvm-project/mlir/include/mlir-c/IR.h for -/// more info +/// Opaque type declarations. Inspired from +/// llvm-project/mlir/include/mlir-c/IR.h +/// +/// Adds an error pointer to an allocated buffer holding the error message if +/// any. #define DEFINE_C_API_STRUCT(name, storage) \ struct name { \ storage *ptr; \ + char *error; \ }; \ typedef struct name name diff --git a/compiler/include/concretelang/CAPI/Wrappers.h b/compiler/include/concretelang/CAPI/Wrappers.h index 6b15e1b7d..dfd0cc12d 100644 --- a/compiler/include/concretelang/CAPI/Wrappers.h +++ b/compiler/include/concretelang/CAPI/Wrappers.h @@ -9,33 +9,54 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/LibrarySupport.h" -#include "mlir/CAPI/Wrap.h" -DEFINE_C_API_PTR_METHODS(CompilerEngine, mlir::concretelang::CompilerEngine) -DEFINE_C_API_PTR_METHODS(CompilationContext, - mlir::concretelang::CompilationContext) -DEFINE_C_API_PTR_METHODS(CompilationResult, - mlir::concretelang::CompilerEngine::CompilationResult) -DEFINE_C_API_PTR_METHODS(Library, mlir::concretelang::CompilerEngine::Library) -DEFINE_C_API_PTR_METHODS(LibraryCompilationResult, - mlir::concretelang::LibraryCompilationResult) -DEFINE_C_API_PTR_METHODS(LibrarySupport, mlir::concretelang::LibrarySupport) -DEFINE_C_API_PTR_METHODS(CompilationOptions, - mlir::concretelang::CompilationOptions) -DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config) -DEFINE_C_API_PTR_METHODS(ServerLambda, - mlir::concretelang::serverlib::ServerLambda) -DEFINE_C_API_PTR_METHODS(ClientParameters, - mlir::concretelang::clientlib::ClientParameters) -DEFINE_C_API_PTR_METHODS(KeySet, mlir::concretelang::clientlib::KeySet) -DEFINE_C_API_PTR_METHODS(KeySetCache, - mlir::concretelang::clientlib::KeySetCache) -DEFINE_C_API_PTR_METHODS(EvaluationKeys, - mlir::concretelang::clientlib::EvaluationKeys) -DEFINE_C_API_PTR_METHODS(LambdaArgument, mlir::concretelang::LambdaArgument) -DEFINE_C_API_PTR_METHODS(PublicArguments, - mlir::concretelang::clientlib::PublicArguments) -DEFINE_C_API_PTR_METHODS(PublicResult, - mlir::concretelang::clientlib::PublicResult) +/// Add a mechanism to go from Cpp objects to C-struct, with the ability to +/// represent errors. Also the other way arround. +#define DEFINE_C_API_PTR_METHODS_WITH_ERROR(name, cpptype) \ + static inline name wrap(cpptype *cpp) { return name{cpp, (char *)NULL}; } \ + static inline name wrap(cpptype *cpp, std::string errorStr) { \ + char *error = new char[errorStr.size()]; \ + strcpy(error, errorStr.c_str()); \ + return name{(cpptype *)NULL, error}; \ + } \ + static inline cpptype *unwrap(name c) { \ + return static_cast(c.ptr); \ + } \ + static inline char *getErrorPtr(name c) { return c.error; } + +DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilerEngine, + mlir::concretelang::CompilerEngine) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationContext, + mlir::concretelang::CompilationContext) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + CompilationResult, mlir::concretelang::CompilerEngine::CompilationResult) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(Library, + mlir::concretelang::CompilerEngine::Library) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + LibraryCompilationResult, mlir::concretelang::LibraryCompilationResult) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(LibrarySupport, + mlir::concretelang::LibrarySupport) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationOptions, + mlir::concretelang::CompilationOptions) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(OptimizerConfig, + mlir::concretelang::optimizer::Config) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(ServerLambda, + mlir::concretelang::serverlib::ServerLambda) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + ClientParameters, mlir::concretelang::clientlib::ClientParameters) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySet, + mlir::concretelang::clientlib::KeySet) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySetCache, + mlir::concretelang::clientlib::KeySetCache) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + EvaluationKeys, mlir::concretelang::clientlib::EvaluationKeys) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(LambdaArgument, + mlir::concretelang::LambdaArgument) +DEFINE_C_API_PTR_METHODS_WITH_ERROR( + PublicArguments, mlir::concretelang::clientlib::PublicArguments) +DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult, + mlir::concretelang::clientlib::PublicResult) + +#undef DEFINE_C_API_PTR_METHODS_WITH_ERROR #endif diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs index 82b5a6fb9..5ce184c34 100644 --- a/compiler/lib/Bindings/Rust/src/compiler.rs +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -1,12 +1,48 @@ //! Compiler module -use std::path::Path; +use std::{ffi::CStr, path::Path}; use crate::mlir::ffi::*; #[derive(Debug)] pub struct CompilerError(String); +/// Retreive buffer of the error message from a C struct. +trait CStructErrorMsg { + fn error_msg(&self) -> *mut i8; +} + +/// All C struct can return a pointer to the allocated error message. +macro_rules! impl_CStructErrorMsg { + ([$($t:ty),+]) => { + $(impl CStructErrorMsg for $t { + fn error_msg(&self) -> *mut i8 { + self.error + } + })* + } +} +impl_CStructErrorMsg! {[ + crate::mlir::ffi::LibrarySupport, + CompilationResult, + LibraryCompilationResult, + ServerLambda, + ClientParameters, + PublicArguments, + PublicResult, + KeySet, + KeySetCache, + LambdaArgument +]} + +/// Construct a rust error message from a buffer in the C struct. +fn get_error_msg_from_ctype(c_struct: T) -> String { + unsafe { + let error_msg_cstr = CStr::from_ptr(c_struct.error_msg()); + String::from(error_msg_cstr.to_str().unwrap()) + } +} + /// Parse the MLIR code and returns it. /// /// The function parse the provided MLIR textual representation and returns it. It would fail with @@ -37,7 +73,12 @@ pub fn round_trip(mlir_code: &str) -> Result { CompilationTarget_ROUND_TRIP, ); if compilationResultIsNull(compilation_result) { - return Err(CompilerError("roundtrip error".to_string())); + let error_msg = get_error_msg_from_ctype(compilation_result); + compilationResultDestroy(compilation_result); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } let module_compiled = compilationResultGetModuleString(compilation_result); let result_str = String::from_utf8_lossy(std::slice::from_raw_parts( @@ -86,7 +127,12 @@ impl LibrarySupport { }, ); if librarySupportIsNull(support) { - return Err(CompilerError("failed creating library support".to_string())); + let error_msg = get_error_msg_from_ctype(support); + librarySupportDestroy(support); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(LibrarySupport { support }) } @@ -110,7 +156,12 @@ impl LibrarySupport { options, ); if libraryCompilationResultIsNull(result) { - return Err(CompilerError("library compilation failed".to_string())); + let error_msg = get_error_msg_from_ctype(result); + libraryCompilationResultDestroy(result); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(result) } @@ -126,7 +177,12 @@ impl LibrarySupport { unsafe { let server = librarySupportLoadServerLambda(self.support, result); if serverLambdaIsNull(server) { - return Err(CompilerError("loading server lambda failed".to_string())); + let error_msg = get_error_msg_from_ctype(server); + serverLambdaDestroy(server); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(server) } @@ -142,9 +198,12 @@ impl LibrarySupport { unsafe { let params = librarySupportLoadClientParameters(self.support, result); if clientParametersIsNull(params) { - return Err(CompilerError( - "loading client parameters failed".to_string(), - )); + let error_msg = get_error_msg_from_ctype(params); + clientParametersDestroy(params); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(params) } @@ -160,7 +219,12 @@ impl LibrarySupport { unsafe { let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys); if publicResultIsNull(result) { - return Err(CompilerError("failed calling server lambda".to_string())); + let error_msg = get_error_msg_from_ctype(result); + publicResultDestroy(result); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(result) } @@ -201,9 +265,12 @@ impl ClientSupport { length: cache_path_buffer.len() as size_t, }); if keySetCacheIsNull(cache) { - return Err(CompilerError( - "failed creating keyset cache from path".to_string(), - )); + let error_msg = get_error_msg_from_ctype(cache); + keySetCacheDestroy(cache); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Some(cache) } @@ -240,7 +307,12 @@ impl ClientSupport { ), }; if keySetIsNull(key_set) { - return Err(CompilerError("getting keyset failed".to_string())); + let error_msg = get_error_msg_from_ctype(key_set); + keySetDestroy(key_set); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(key_set) } @@ -260,7 +332,12 @@ impl ClientSupport { key_set, ); if publicArgumentsIsNull(public_args) { - return Err(CompilerError("encryption failed".to_string())); + let error_msg = get_error_msg_from_ctype(public_args); + publicArgumentsDestroy(public_args); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(public_args) } @@ -274,7 +351,12 @@ impl ClientSupport { unsafe { let arg = publicResultDecrypt(result, key_set); if lambdaArgumentIsNull(arg) { - return Err(CompilerError("decryption failed".to_string())); + let error_msg = get_error_msg_from_ctype(arg); + lambdaArgumentDestroy(arg); + return Err(CompilerError(format!( + "Error in compiler (check logs for more info): {}", + error_msg + ))); } Ok(arg) } @@ -310,7 +392,9 @@ mod test { fn test_compiler_round_trip_invalid_mlir() { let module_to_compile = "bla bla bla"; let result_str = round_trip(module_to_compile); - assert!(matches!(result_str, Err(CompilerError(_)))); + assert!( + matches!(result_str, Err(CompilerError(err)) if err == "Error in compiler (check logs for more info): Could not parse source\n") + ); } #[test] @@ -345,6 +429,7 @@ mod test { unsafe { let lib = Library { ptr: std::ptr::null_mut(), + error: std::ptr::null_mut(), }; assert!(libraryIsNull(lib)); } diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 15d54c136..324eddbd8 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -11,6 +11,14 @@ #include "mlir/IR/Diagnostics.h" #include "llvm/Support/SourceMgr.h" +#define C_STRUCT_CLEANER(c_struct) \ + auto *cpp = unwrap(c_struct); \ + if (cpp != NULL) \ + delete cpp; \ + char *error = getErrorPtr(c_struct); \ + if (error != NULL) \ + delete[] error; + /// ********** CompilationOptions CAPI ***************************************** CompilationOptions @@ -66,11 +74,11 @@ CompilerEngine compilerEngineCreate() { return wrap(engine); } -void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); } +void compilerEngineDestroy(CompilerEngine engine){C_STRUCT_CLEANER(engine)} /// Map C compilationTarget to Cpp -llvm::Expected -targetConvertToCppFromC(CompilationTarget target) { +llvm::Expected targetConvertToCppFromC(CompilationTarget target) { switch (target) { case ROUND_TRIP: return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP; @@ -104,13 +112,13 @@ CompilationResult compilerEngineCompile(CompilerEngine engine, std::string module_str(module.data, module.length); auto targetCppOrError = targetConvertToCppFromC(target); if (!targetCppOrError) { // invalid target - llvm::errs() << llvm::toString(targetCppOrError.takeError()); - return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL); + return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL, + llvm::toString(targetCppOrError.takeError())); } auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get()); if (!retOrError) { // compilation error - llvm::errs() << llvm::toString(retOrError.takeError()); - return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL); + return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL, + llvm::toString(retOrError.takeError())); } return wrap(new mlir::concretelang::CompilerEngine::CompilationResult( std::move(retOrError.get()))); @@ -138,9 +146,8 @@ void compilationResultDestroyModuleString(MlirStringRef str) { delete str.data; } -void compilationResultDestroy(CompilationResult result) { - delete unwrap(result); -} +void compilationResultDestroy(CompilationResult result){ + C_STRUCT_CLEANER(result)} /// ********** Library CAPI **************************************************** @@ -153,21 +160,22 @@ Library libraryCreate(MlirStringRef outputDirPath, outputDirPathStr, runtimeLibraryPathStr, cleanUp)); } -void libraryDestroy(Library lib) { delete unwrap(lib); } +void libraryDestroy(Library lib) { C_STRUCT_CLEANER(lib) } /// ********** LibraryCompilationResult CAPI *********************************** -void libraryCompilationResultDestroy(LibraryCompilationResult result) { - delete unwrap(result); -} +void libraryCompilationResultDestroy(LibraryCompilationResult result){ + C_STRUCT_CLEANER(result)} /// ********** LibrarySupport CAPI ********************************************* LibrarySupport -librarySupportCreate(MlirStringRef outputDirPath, - MlirStringRef runtimeLibraryPath, bool generateSharedLib, - bool generateStaticLib, bool generateClientParameters, - bool generateCompilationFeedback, bool generateCppHeader) { + librarySupportCreate(MlirStringRef outputDirPath, + MlirStringRef runtimeLibraryPath, + bool generateSharedLib, bool generateStaticLib, + bool generateClientParameters, + bool generateCompilationFeedback, + bool generateCppHeader) { std::string outputDirPathStr(outputDirPath.data, outputDirPath.length); std::string runtimeLibraryPathStr(runtimeLibraryPath.data, runtimeLibraryPath.length); @@ -183,8 +191,8 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support, std::string moduleStr(module.data, module.length); auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options)); if (!retOrError) { - llvm::errs() << llvm::toString(retOrError.takeError()); - return wrap((mlir::concretelang::LibraryCompilationResult *)NULL); + return wrap((mlir::concretelang::LibraryCompilationResult *)NULL, + llvm::toString(retOrError.takeError())); } return wrap(new mlir::concretelang::LibraryCompilationResult( *retOrError.get().release())); @@ -194,8 +202,8 @@ ServerLambda librarySupportLoadServerLambda(LibrarySupport support, LibraryCompilationResult result) { auto serverLambdaOrError = unwrap(support)->loadServerLambda(*unwrap(result)); if (!serverLambdaOrError) { - llvm::errs() << llvm::toString(serverLambdaOrError.takeError()); - return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL); + return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL, + llvm::toString(serverLambdaOrError.takeError())); } return wrap(new mlir::concretelang::serverlib::ServerLambda( serverLambdaOrError.get())); @@ -206,8 +214,8 @@ librarySupportLoadClientParameters(LibrarySupport support, LibraryCompilationResult result) { auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result)); if (!paramsOrError) { - llvm::errs() << llvm::toString(paramsOrError.takeError()); - return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL); + return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL, + llvm::toString(paramsOrError.takeError())); } return wrap( new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get())); @@ -220,21 +228,21 @@ PublicResult librarySupportServerCall(LibrarySupport support, auto resultOrError = unwrap(support)->serverCall( *unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys)); if (!resultOrError) { - llvm::errs() << llvm::toString(resultOrError.takeError()); - return wrap((mlir::concretelang::clientlib::PublicResult *)NULL); + return wrap((mlir::concretelang::clientlib::PublicResult *)NULL, + llvm::toString(resultOrError.takeError())); } return wrap(resultOrError.get().release()); } -void librarySupportDestroy(LibrarySupport support) { delete unwrap(support); } +void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) } /// ********** ServerLamda CAPI ************************************************ -void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); } +void serverLambdaDestroy(ServerLambda server) { C_STRUCT_CLEANER(server) } /// ********** ClientParameters CAPI ******************************************* -void clientParametersDestroy(ClientParameters params) { delete unwrap(params); } +void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)} /// ********** KeySet CAPI ***************************************************** @@ -243,8 +251,8 @@ KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb, auto keySet = mlir::concretelang::clientlib::KeySet::generate( *unwrap(params), seed_msb, seed_lsb); if (keySet.has_error()) { - llvm::errs() << keySet.error().mesg; - return wrap((mlir::concretelang::clientlib::KeySet *)NULL); + return wrap((mlir::concretelang::clientlib::KeySet *)NULL, + keySet.error().mesg); } return wrap(keySet.value().release()); } @@ -254,7 +262,7 @@ EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) { unwrap(keySet)->evaluationKeys())); } -void keySetDestroy(KeySet keySet) { delete unwrap(keySet); } +void keySetDestroy(KeySet keySet){C_STRUCT_CLEANER(keySet)} /// ********** KeySetCache CAPI ************************************************ @@ -269,18 +277,20 @@ KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache, auto keySetOrError = unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb); if (keySetOrError.has_error()) { - llvm::errs() << keySetOrError.error().mesg; - return wrap((mlir::concretelang::clientlib::KeySet *)NULL); + return wrap((mlir::concretelang::clientlib::KeySet *)NULL, + keySetOrError.error().mesg); } return wrap(keySetOrError.value().release()); } -void keySetCacheDestroy(KeySetCache keySetCache) { delete unwrap(keySetCache); } +void keySetCacheDestroy(KeySetCache keySetCache) { + C_STRUCT_CLEANER(keySetCache) +} /// ********** EvaluationKeys CAPI ********************************************* void evaluationKeysDestroy(EvaluationKeys evaluationKeys) { - delete unwrap(evaluationKeys); + C_STRUCT_CLEANER(evaluationKeys); } /// ********** LambdaArgument CAPI ********************************************* @@ -334,21 +344,20 @@ PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs, mlir::concretelang::LambdaSupport::exportArguments( *unwrap(params), *unwrap(keySet), args); if (!publicArgsOrError) { - llvm::errs() << llvm::toString(publicArgsOrError.takeError()); - return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL); + return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL, + llvm::toString(publicArgsOrError.takeError())); } return wrap(publicArgsOrError.get().release()); } void lambdaArgumentDestroy(LambdaArgument lambdaArg) { - delete unwrap(lambdaArg); + C_STRUCT_CLEANER(lambdaArg) } /// ********** PublicArguments CAPI ******************************************** -void publicArgumentsDestroy(PublicArguments publicArgs) { - delete unwrap(publicArgs); -} +void publicArgumentsDestroy(PublicArguments publicArgs){ + C_STRUCT_CLEANER(publicArgs)} /// ********** PublicResult CAPI *********************************************** @@ -358,12 +367,12 @@ LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) { std::unique_ptr>( *unwrap(keySet), *unwrap(publicResult)); if (!lambdaArgOrError) { - llvm::errs() << llvm::toString(lambdaArgOrError.takeError()); - return wrap((mlir::concretelang::LambdaArgument *)NULL); + return wrap((mlir::concretelang::LambdaArgument *)NULL, + llvm::toString(lambdaArgOrError.takeError())); } return wrap(lambdaArgOrError.get().release()); } void publicResultDestroy(PublicResult publicResult) { - delete unwrap(publicResult); + C_STRUCT_CLEANER(publicResult) }