feat(rust): manage erros coming from compiler

C struct now contains an additonal char* pointer, which can be either
NULL in case there is no error, or a buffer containing the error
message. It's the responsability of destructor function to free that
memory.
This commit is contained in:
youben11
2022-11-28 11:11:26 +01:00
committed by Ayoub Benaissa
parent 15b4aac0a1
commit 16f3b0bbf6
4 changed files with 209 additions and 94 deletions

View File

@@ -12,15 +12,15 @@
extern "C" {
#endif
// TODO: add a char* to the struct that can return an error message in case of
// an error (where ptr would be null). Error messages can be returned using
// llvm::toString(error.takeError()) and allocating a buffer for the message and
// copy it. The buffer can later be freed during struct destruction.
/// Opaque type declarations. Refer to llvm-project/mlir/include/mlir-c/IR.h for
/// more info
/// Opaque type declarations. Inspired from
/// llvm-project/mlir/include/mlir-c/IR.h
///
/// Adds an error pointer to an allocated buffer holding the error message if
/// any.
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
char *error; \
}; \
typedef struct name name

View File

@@ -9,33 +9,54 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/LibrarySupport.h"
#include "mlir/CAPI/Wrap.h"
DEFINE_C_API_PTR_METHODS(CompilerEngine, mlir::concretelang::CompilerEngine)
DEFINE_C_API_PTR_METHODS(CompilationContext,
mlir::concretelang::CompilationContext)
DEFINE_C_API_PTR_METHODS(CompilationResult,
mlir::concretelang::CompilerEngine::CompilationResult)
DEFINE_C_API_PTR_METHODS(Library, mlir::concretelang::CompilerEngine::Library)
DEFINE_C_API_PTR_METHODS(LibraryCompilationResult,
mlir::concretelang::LibraryCompilationResult)
DEFINE_C_API_PTR_METHODS(LibrarySupport, mlir::concretelang::LibrarySupport)
DEFINE_C_API_PTR_METHODS(CompilationOptions,
mlir::concretelang::CompilationOptions)
DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config)
DEFINE_C_API_PTR_METHODS(ServerLambda,
mlir::concretelang::serverlib::ServerLambda)
DEFINE_C_API_PTR_METHODS(ClientParameters,
mlir::concretelang::clientlib::ClientParameters)
DEFINE_C_API_PTR_METHODS(KeySet, mlir::concretelang::clientlib::KeySet)
DEFINE_C_API_PTR_METHODS(KeySetCache,
mlir::concretelang::clientlib::KeySetCache)
DEFINE_C_API_PTR_METHODS(EvaluationKeys,
mlir::concretelang::clientlib::EvaluationKeys)
DEFINE_C_API_PTR_METHODS(LambdaArgument, mlir::concretelang::LambdaArgument)
DEFINE_C_API_PTR_METHODS(PublicArguments,
mlir::concretelang::clientlib::PublicArguments)
DEFINE_C_API_PTR_METHODS(PublicResult,
mlir::concretelang::clientlib::PublicResult)
/// Add a mechanism to go from Cpp objects to C-struct, with the ability to
/// represent errors. Also the other way arround.
#define DEFINE_C_API_PTR_METHODS_WITH_ERROR(name, cpptype) \
static inline name wrap(cpptype *cpp) { return name{cpp, (char *)NULL}; } \
static inline name wrap(cpptype *cpp, std::string errorStr) { \
char *error = new char[errorStr.size()]; \
strcpy(error, errorStr.c_str()); \
return name{(cpptype *)NULL, error}; \
} \
static inline cpptype *unwrap(name c) { \
return static_cast<cpptype *>(c.ptr); \
} \
static inline char *getErrorPtr(name c) { return c.error; }
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilerEngine,
mlir::concretelang::CompilerEngine)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationContext,
mlir::concretelang::CompilationContext)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
CompilationResult, mlir::concretelang::CompilerEngine::CompilationResult)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(Library,
mlir::concretelang::CompilerEngine::Library)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
LibraryCompilationResult, mlir::concretelang::LibraryCompilationResult)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(LibrarySupport,
mlir::concretelang::LibrarySupport)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(CompilationOptions,
mlir::concretelang::CompilationOptions)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(OptimizerConfig,
mlir::concretelang::optimizer::Config)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(ServerLambda,
mlir::concretelang::serverlib::ServerLambda)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
ClientParameters, mlir::concretelang::clientlib::ClientParameters)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySet,
mlir::concretelang::clientlib::KeySet)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(KeySetCache,
mlir::concretelang::clientlib::KeySetCache)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
EvaluationKeys, mlir::concretelang::clientlib::EvaluationKeys)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(LambdaArgument,
mlir::concretelang::LambdaArgument)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(
PublicArguments, mlir::concretelang::clientlib::PublicArguments)
DEFINE_C_API_PTR_METHODS_WITH_ERROR(PublicResult,
mlir::concretelang::clientlib::PublicResult)
#undef DEFINE_C_API_PTR_METHODS_WITH_ERROR
#endif

View File

@@ -1,12 +1,48 @@
//! Compiler module
use std::path::Path;
use std::{ffi::CStr, path::Path};
use crate::mlir::ffi::*;
#[derive(Debug)]
pub struct CompilerError(String);
/// Retreive buffer of the error message from a C struct.
trait CStructErrorMsg {
fn error_msg(&self) -> *mut i8;
}
/// All C struct can return a pointer to the allocated error message.
macro_rules! impl_CStructErrorMsg {
([$($t:ty),+]) => {
$(impl CStructErrorMsg for $t {
fn error_msg(&self) -> *mut i8 {
self.error
}
})*
}
}
impl_CStructErrorMsg! {[
crate::mlir::ffi::LibrarySupport,
CompilationResult,
LibraryCompilationResult,
ServerLambda,
ClientParameters,
PublicArguments,
PublicResult,
KeySet,
KeySetCache,
LambdaArgument
]}
/// Construct a rust error message from a buffer in the C struct.
fn get_error_msg_from_ctype<T: CStructErrorMsg>(c_struct: T) -> String {
unsafe {
let error_msg_cstr = CStr::from_ptr(c_struct.error_msg());
String::from(error_msg_cstr.to_str().unwrap())
}
}
/// Parse the MLIR code and returns it.
///
/// The function parse the provided MLIR textual representation and returns it. It would fail with
@@ -37,7 +73,12 @@ pub fn round_trip(mlir_code: &str) -> Result<String, CompilerError> {
CompilationTarget_ROUND_TRIP,
);
if compilationResultIsNull(compilation_result) {
return Err(CompilerError("roundtrip error".to_string()));
let error_msg = get_error_msg_from_ctype(compilation_result);
compilationResultDestroy(compilation_result);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
let module_compiled = compilationResultGetModuleString(compilation_result);
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
@@ -86,7 +127,12 @@ impl LibrarySupport {
},
);
if librarySupportIsNull(support) {
return Err(CompilerError("failed creating library support".to_string()));
let error_msg = get_error_msg_from_ctype(support);
librarySupportDestroy(support);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(LibrarySupport { support })
}
@@ -110,7 +156,12 @@ impl LibrarySupport {
options,
);
if libraryCompilationResultIsNull(result) {
return Err(CompilerError("library compilation failed".to_string()));
let error_msg = get_error_msg_from_ctype(result);
libraryCompilationResultDestroy(result);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(result)
}
@@ -126,7 +177,12 @@ impl LibrarySupport {
unsafe {
let server = librarySupportLoadServerLambda(self.support, result);
if serverLambdaIsNull(server) {
return Err(CompilerError("loading server lambda failed".to_string()));
let error_msg = get_error_msg_from_ctype(server);
serverLambdaDestroy(server);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(server)
}
@@ -142,9 +198,12 @@ impl LibrarySupport {
unsafe {
let params = librarySupportLoadClientParameters(self.support, result);
if clientParametersIsNull(params) {
return Err(CompilerError(
"loading client parameters failed".to_string(),
));
let error_msg = get_error_msg_from_ctype(params);
clientParametersDestroy(params);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(params)
}
@@ -160,7 +219,12 @@ impl LibrarySupport {
unsafe {
let result = librarySupportServerCall(self.support, server_lambda, args, eval_keys);
if publicResultIsNull(result) {
return Err(CompilerError("failed calling server lambda".to_string()));
let error_msg = get_error_msg_from_ctype(result);
publicResultDestroy(result);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(result)
}
@@ -201,9 +265,12 @@ impl ClientSupport {
length: cache_path_buffer.len() as size_t,
});
if keySetCacheIsNull(cache) {
return Err(CompilerError(
"failed creating keyset cache from path".to_string(),
));
let error_msg = get_error_msg_from_ctype(cache);
keySetCacheDestroy(cache);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Some(cache)
}
@@ -240,7 +307,12 @@ impl ClientSupport {
),
};
if keySetIsNull(key_set) {
return Err(CompilerError("getting keyset failed".to_string()));
let error_msg = get_error_msg_from_ctype(key_set);
keySetDestroy(key_set);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(key_set)
}
@@ -260,7 +332,12 @@ impl ClientSupport {
key_set,
);
if publicArgumentsIsNull(public_args) {
return Err(CompilerError("encryption failed".to_string()));
let error_msg = get_error_msg_from_ctype(public_args);
publicArgumentsDestroy(public_args);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(public_args)
}
@@ -274,7 +351,12 @@ impl ClientSupport {
unsafe {
let arg = publicResultDecrypt(result, key_set);
if lambdaArgumentIsNull(arg) {
return Err(CompilerError("decryption failed".to_string()));
let error_msg = get_error_msg_from_ctype(arg);
lambdaArgumentDestroy(arg);
return Err(CompilerError(format!(
"Error in compiler (check logs for more info): {}",
error_msg
)));
}
Ok(arg)
}
@@ -310,7 +392,9 @@ mod test {
fn test_compiler_round_trip_invalid_mlir() {
let module_to_compile = "bla bla bla";
let result_str = round_trip(module_to_compile);
assert!(matches!(result_str, Err(CompilerError(_))));
assert!(
matches!(result_str, Err(CompilerError(err)) if err == "Error in compiler (check logs for more info): Could not parse source\n")
);
}
#[test]
@@ -345,6 +429,7 @@ mod test {
unsafe {
let lib = Library {
ptr: std::ptr::null_mut(),
error: std::ptr::null_mut(),
};
assert!(libraryIsNull(lib));
}

View File

@@ -11,6 +11,14 @@
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
#define C_STRUCT_CLEANER(c_struct) \
auto *cpp = unwrap(c_struct); \
if (cpp != NULL) \
delete cpp; \
char *error = getErrorPtr(c_struct); \
if (error != NULL) \
delete[] error;
/// ********** CompilationOptions CAPI *****************************************
CompilationOptions
@@ -66,11 +74,11 @@ CompilerEngine compilerEngineCreate() {
return wrap(engine);
}
void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); }
void compilerEngineDestroy(CompilerEngine engine){C_STRUCT_CLEANER(engine)}
/// Map C compilationTarget to Cpp
llvm::Expected<mlir::concretelang::CompilerEngine::Target>
targetConvertToCppFromC(CompilationTarget target) {
llvm::Expected<mlir::concretelang::CompilerEngine::
Target> targetConvertToCppFromC(CompilationTarget target) {
switch (target) {
case ROUND_TRIP:
return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP;
@@ -104,13 +112,13 @@ CompilationResult compilerEngineCompile(CompilerEngine engine,
std::string module_str(module.data, module.length);
auto targetCppOrError = targetConvertToCppFromC(target);
if (!targetCppOrError) { // invalid target
llvm::errs() << llvm::toString(targetCppOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
llvm::toString(targetCppOrError.takeError()));
}
auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get());
if (!retOrError) { // compilation error
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL,
llvm::toString(retOrError.takeError()));
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
@@ -138,9 +146,8 @@ void compilationResultDestroyModuleString(MlirStringRef str) {
delete str.data;
}
void compilationResultDestroy(CompilationResult result) {
delete unwrap(result);
}
void compilationResultDestroy(CompilationResult result){
C_STRUCT_CLEANER(result)}
/// ********** Library CAPI ****************************************************
@@ -153,21 +160,22 @@ Library libraryCreate(MlirStringRef outputDirPath,
outputDirPathStr, runtimeLibraryPathStr, cleanUp));
}
void libraryDestroy(Library lib) { delete unwrap(lib); }
void libraryDestroy(Library lib) { C_STRUCT_CLEANER(lib) }
/// ********** LibraryCompilationResult CAPI ***********************************
void libraryCompilationResultDestroy(LibraryCompilationResult result) {
delete unwrap(result);
}
void libraryCompilationResultDestroy(LibraryCompilationResult result){
C_STRUCT_CLEANER(result)}
/// ********** LibrarySupport CAPI *********************************************
LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader) {
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath,
bool generateSharedLib, bool generateStaticLib,
bool generateClientParameters,
bool generateCompilationFeedback,
bool generateCppHeader) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
@@ -183,8 +191,8 @@ LibraryCompilationResult librarySupportCompile(LibrarySupport support,
std::string moduleStr(module.data, module.length);
auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options));
if (!retOrError) {
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL);
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL,
llvm::toString(retOrError.takeError()));
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().release()));
@@ -194,8 +202,8 @@ ServerLambda librarySupportLoadServerLambda(LibrarySupport support,
LibraryCompilationResult result) {
auto serverLambdaOrError = unwrap(support)->loadServerLambda(*unwrap(result));
if (!serverLambdaOrError) {
llvm::errs() << llvm::toString(serverLambdaOrError.takeError());
return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL);
return wrap((mlir::concretelang::serverlib::ServerLambda *)NULL,
llvm::toString(serverLambdaOrError.takeError()));
}
return wrap(new mlir::concretelang::serverlib::ServerLambda(
serverLambdaOrError.get()));
@@ -206,8 +214,8 @@ librarySupportLoadClientParameters(LibrarySupport support,
LibraryCompilationResult result) {
auto paramsOrError = unwrap(support)->loadClientParameters(*unwrap(result));
if (!paramsOrError) {
llvm::errs() << llvm::toString(paramsOrError.takeError());
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL);
return wrap((mlir::concretelang::clientlib::ClientParameters *)NULL,
llvm::toString(paramsOrError.takeError()));
}
return wrap(
new mlir::concretelang::clientlib::ClientParameters(paramsOrError.get()));
@@ -220,21 +228,21 @@ PublicResult librarySupportServerCall(LibrarySupport support,
auto resultOrError = unwrap(support)->serverCall(
*unwrap(server_lambda), *unwrap(args), *unwrap(evalKeys));
if (!resultOrError) {
llvm::errs() << llvm::toString(resultOrError.takeError());
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL);
return wrap((mlir::concretelang::clientlib::PublicResult *)NULL,
llvm::toString(resultOrError.takeError()));
}
return wrap(resultOrError.get().release());
}
void librarySupportDestroy(LibrarySupport support) { delete unwrap(support); }
void librarySupportDestroy(LibrarySupport support) { C_STRUCT_CLEANER(support) }
/// ********** ServerLamda CAPI ************************************************
void serverLambdaDestroy(ServerLambda server) { delete unwrap(server); }
void serverLambdaDestroy(ServerLambda server) { C_STRUCT_CLEANER(server) }
/// ********** ClientParameters CAPI *******************************************
void clientParametersDestroy(ClientParameters params) { delete unwrap(params); }
void clientParametersDestroy(ClientParameters params){C_STRUCT_CLEANER(params)}
/// ********** KeySet CAPI *****************************************************
@@ -243,8 +251,8 @@ KeySet keySetGenerate(ClientParameters params, uint64_t seed_msb,
auto keySet = mlir::concretelang::clientlib::KeySet::generate(
*unwrap(params), seed_msb, seed_lsb);
if (keySet.has_error()) {
llvm::errs() << keySet.error().mesg;
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
keySet.error().mesg);
}
return wrap(keySet.value().release());
}
@@ -254,7 +262,7 @@ EvaluationKeys keySetGetEvaluationKeys(KeySet keySet) {
unwrap(keySet)->evaluationKeys()));
}
void keySetDestroy(KeySet keySet) { delete unwrap(keySet); }
void keySetDestroy(KeySet keySet){C_STRUCT_CLEANER(keySet)}
/// ********** KeySetCache CAPI ************************************************
@@ -269,18 +277,20 @@ KeySet keySetCacheLoadOrGenerateKeySet(KeySetCache cache,
auto keySetOrError =
unwrap(cache)->generate(*unwrap(params), seed_msb, seed_lsb);
if (keySetOrError.has_error()) {
llvm::errs() << keySetOrError.error().mesg;
return wrap((mlir::concretelang::clientlib::KeySet *)NULL);
return wrap((mlir::concretelang::clientlib::KeySet *)NULL,
keySetOrError.error().mesg);
}
return wrap(keySetOrError.value().release());
}
void keySetCacheDestroy(KeySetCache keySetCache) { delete unwrap(keySetCache); }
void keySetCacheDestroy(KeySetCache keySetCache) {
C_STRUCT_CLEANER(keySetCache)
}
/// ********** EvaluationKeys CAPI *********************************************
void evaluationKeysDestroy(EvaluationKeys evaluationKeys) {
delete unwrap(evaluationKeys);
C_STRUCT_CLEANER(evaluationKeys);
}
/// ********** LambdaArgument CAPI *********************************************
@@ -334,21 +344,20 @@ PublicArguments lambdaArgumentEncrypt(const LambdaArgument *lambdaArgs,
mlir::concretelang::LambdaSupport<int, int>::exportArguments(
*unwrap(params), *unwrap(keySet), args);
if (!publicArgsOrError) {
llvm::errs() << llvm::toString(publicArgsOrError.takeError());
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL);
return wrap((mlir::concretelang::clientlib::PublicArguments *)NULL,
llvm::toString(publicArgsOrError.takeError()));
}
return wrap(publicArgsOrError.get().release());
}
void lambdaArgumentDestroy(LambdaArgument lambdaArg) {
delete unwrap(lambdaArg);
C_STRUCT_CLEANER(lambdaArg)
}
/// ********** PublicArguments CAPI ********************************************
void publicArgumentsDestroy(PublicArguments publicArgs) {
delete unwrap(publicArgs);
}
void publicArgumentsDestroy(PublicArguments publicArgs){
C_STRUCT_CLEANER(publicArgs)}
/// ********** PublicResult CAPI ***********************************************
@@ -358,12 +367,12 @@ LambdaArgument publicResultDecrypt(PublicResult publicResult, KeySet keySet) {
std::unique_ptr<mlir::concretelang::LambdaArgument>>(
*unwrap(keySet), *unwrap(publicResult));
if (!lambdaArgOrError) {
llvm::errs() << llvm::toString(lambdaArgOrError.takeError());
return wrap((mlir::concretelang::LambdaArgument *)NULL);
return wrap((mlir::concretelang::LambdaArgument *)NULL,
llvm::toString(lambdaArgOrError.takeError()));
}
return wrap(lambdaArgOrError.get().release());
}
void publicResultDestroy(PublicResult publicResult) {
delete unwrap(publicResult);
C_STRUCT_CLEANER(publicResult)
}