feat(rust): compile mlir into library

CAPI covering a wider API of the Support library.
Better error handling. Could also be improved by returning an error
message back from C to rust (left TODO).
This commit is contained in:
youben11
2022-11-24 10:42:18 +01:00
committed by Ayoub Benaissa
parent 8a557368f1
commit b00115f4ae
5 changed files with 344 additions and 20 deletions

View File

@@ -12,6 +12,10 @@
extern "C" {
#endif
// TODO: add a char* to the struct that can return an error message in case of
// an error (where ptr would be null). Error messages can be returned using
// llvm::toString(error.takeError()) and allocating a buffer for the message and
// copy it. The buffer can later be freed during struct destruction.
/// Opaque type declarations. Refer to llvm-project/mlir/include/mlir-c/IR.h for
/// more info
#define DEFINE_C_API_STRUCT(name, storage) \
@@ -23,6 +27,11 @@ extern "C" {
DEFINE_C_API_STRUCT(CompilerEngine, void);
DEFINE_C_API_STRUCT(CompilationContext, void);
DEFINE_C_API_STRUCT(CompilationResult, void);
DEFINE_C_API_STRUCT(Library, void);
DEFINE_C_API_STRUCT(LibraryCompilationResult, void);
DEFINE_C_API_STRUCT(LibrarySupport, void);
DEFINE_C_API_STRUCT(CompilationOptions, void);
DEFINE_C_API_STRUCT(OptimizerConfig, void);
#undef DEFINE_C_API_STRUCT
@@ -32,7 +41,14 @@ DEFINE_C_API_STRUCT(CompilationResult, void);
bool funcname(storage s) { return s.ptr == NULL; }
DEFINE_NULL_PTR_CHECKER(compilerEngineIsNull, CompilerEngine);
DEFINE_NULL_PTR_CHECKER(compilationContextIsNull, CompilationContext);
DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult);
DEFINE_NULL_PTR_CHECKER(libraryIsNull, Library);
DEFINE_NULL_PTR_CHECKER(libraryCompilationResultIsNull,
LibraryCompilationResult);
DEFINE_NULL_PTR_CHECKER(librarySupportIsNull, LibrarySupport);
DEFINE_NULL_PTR_CHECKER(compilationOptionsIsNull, CompilationOptions);
DEFINE_NULL_PTR_CHECKER(optimizerConfigIsNull, OptimizerConfig);
#undef DEFINE_NULL_PTR_CHECKER
@@ -42,9 +58,33 @@ DEFINE_NULL_PTR_CHECKER(compilationResultIsNull, CompilationResult);
/// ********** CompilationTarget CAPI ******************************************
enum CompilationTarget { ROUND_TRIP, OTHER };
enum CompilationTarget {
ROUND_TRIP,
FHE,
TFHE,
CONCRETE,
CONCRETEWITHLOOPS,
BCONCRETE,
STD,
LLVM,
LLVM_IR,
OPTIMIZED_LLVM_IR,
LIBRARY
};
typedef enum CompilationTarget CompilationTarget;
/// ********** CompilationOptions CAPI *****************************************
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreate();
MLIR_CAPI_EXPORTED CompilationOptions compilationOptionsCreateDefault();
/// ********** OptimizerConfig CAPI ********************************************
MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreate();
MLIR_CAPI_EXPORTED OptimizerConfig optimizerConfigCreateDefault();
/// ********** CompilerEngine CAPI *********************************************
MLIR_CAPI_EXPORTED CompilerEngine compilerEngineCreate();
@@ -54,6 +94,10 @@ MLIR_CAPI_EXPORTED void compilerEngineDestroy(CompilerEngine engine);
MLIR_CAPI_EXPORTED CompilationResult compilerEngineCompile(
CompilerEngine engine, MlirStringRef module, CompilationTarget target);
MLIR_CAPI_EXPORTED void
compilerEngineCompileSetOptions(CompilerEngine engine,
CompilationOptions options);
/// ********** CompilationResult CAPI ******************************************
/// Get a string reference holding the textual representation of the compiled
@@ -65,6 +109,37 @@ compilationResultGetModuleString(CompilationResult result);
/// Free memory allocated for the module string.
MLIR_CAPI_EXPORTED void compilationResultDestroyModuleString(MlirStringRef str);
MLIR_CAPI_EXPORTED void compilationResultDestroy(CompilationResult result);
/// ********** Library CAPI ****************************************************
MLIR_CAPI_EXPORTED Library libraryCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath,
bool cleanUp);
MLIR_CAPI_EXPORTED void libraryDestroy(Library lib);
/// ********** LibraryCompilationResult CAPI ***********************************
MLIR_CAPI_EXPORTED void
libraryCompilationResultDestroy(LibraryCompilationResult result);
/// ********** LibrarySupport CAPI *********************************************
MLIR_CAPI_EXPORTED LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader);
MLIR_CAPI_EXPORTED LibrarySupport librarySupportCreateDefault(
MlirStringRef outputDirPath, MlirStringRef runtimeLibraryPath) {
return librarySupportCreate(outputDirPath, runtimeLibraryPath, true, true,
true, true, true);
}
MLIR_CAPI_EXPORTED LibraryCompilationResult librarySupportCompile(
LibrarySupport support, MlirStringRef module, CompilationOptions options);
#ifdef __cplusplus
}
#endif

View File

@@ -8,10 +8,20 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/LibrarySupport.h"
#include "mlir/CAPI/Wrap.h"
DEFINE_C_API_PTR_METHODS(CompilerEngine, mlir::concretelang::CompilerEngine)
DEFINE_C_API_PTR_METHODS(CompilationContext,
mlir::concretelang::CompilationContext)
DEFINE_C_API_PTR_METHODS(CompilationResult,
mlir::concretelang::CompilerEngine::CompilationResult)
DEFINE_C_API_PTR_METHODS(Library, mlir::concretelang::CompilerEngine::Library)
DEFINE_C_API_PTR_METHODS(LibraryCompilationResult,
mlir::concretelang::LibraryCompilationResult)
DEFINE_C_API_PTR_METHODS(LibrarySupport, mlir::concretelang::LibrarySupport)
DEFINE_C_API_PTR_METHODS(CompilationOptions,
mlir::concretelang::CompilationOptions)
DEFINE_C_API_PTR_METHODS(OptimizerConfig, mlir::concretelang::optimizer::Config)
#endif

View File

@@ -6,3 +6,5 @@ edition = "2021"
[build-dependencies]
bindgen = "0.60.1"
[dev-dependencies]
tempdir = "0.3.7"

View File

@@ -2,6 +2,9 @@
use crate::mlir::ffi::*;
#[derive(Debug)]
pub struct CompilationError(String);
/// Parse the MLIR code and returns it.
///
/// The function parse the provided MLIR textual representation and returns it. It would fail with
@@ -19,7 +22,7 @@ use crate::mlir::ffi::*;
/// let result_str = round_trip(module_to_compile);
/// ```
///
pub fn round_trip(mlir_code: &str) -> String {
pub fn round_trip(mlir_code: &str) -> Result<String, CompilationError> {
unsafe {
let engine = compilerEngineCreate();
let mlir_code_buffer = mlir_code.as_bytes();
@@ -31,6 +34,9 @@ pub fn round_trip(mlir_code: &str) -> String {
},
CompilationTarget_ROUND_TRIP,
);
if compilationResultIsNull(compilation_result) {
return Err(CompilationError("roundtrip error".to_string()));
}
let module_compiled = compilationResultGetModuleString(compilation_result);
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
module_compiled.data as *const u8,
@@ -39,12 +45,68 @@ pub fn round_trip(mlir_code: &str) -> String {
.to_string();
compilationResultDestroyModuleString(module_compiled);
compilerEngineDestroy(engine);
result_str
Ok(result_str)
}
}
/// Support for compiling and executing libraries.
pub struct LibrarySupport {
support: crate::mlir::ffi::LibrarySupport,
}
impl LibrarySupport {
/// LibrarySupport manages build files generated by the compiler under the `output_dir_path`.
///
/// The compiled library needs to link to the runtime for proper execution.
pub fn new(output_dir_path: &str, runtime_library_path: &str) -> LibrarySupport {
unsafe {
let output_dir_path_buffer = output_dir_path.as_bytes();
let runtime_library_path_buffer = runtime_library_path.as_bytes();
LibrarySupport {
support: librarySupportCreateDefault(
MlirStringRef {
data: output_dir_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: output_dir_path_buffer.len() as size_t,
},
MlirStringRef {
data: runtime_library_path_buffer.as_ptr() as *const std::os::raw::c_char,
length: runtime_library_path_buffer.len() as size_t,
},
),
}
}
}
/// Compile an MLIR into a library.
pub fn compile(
&self,
mlir_code: &str,
options: Option<CompilationOptions>,
) -> Result<LibraryCompilationResult, CompilationError> {
unsafe {
let options = options.unwrap_or_else(|| compilationOptionsCreateDefault());
let mlir_code_buffer = mlir_code.as_bytes();
let result = librarySupportCompile(
self.support,
MlirStringRef {
data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char,
length: mlir_code_buffer.len() as size_t,
},
options,
);
if libraryCompilationResultIsNull(result) {
return Err(CompilationError("library compilation failed".to_string()));
}
Ok(result)
}
}
}
#[cfg(test)]
mod test {
use std::env;
use tempdir::TempDir;
use super::*;
#[test]
@@ -54,7 +116,7 @@ mod test {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let result_str = round_trip(module_to_compile);
let result_str = round_trip(module_to_compile).unwrap();
let expected_module = "module {
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
@@ -64,4 +126,47 @@ mod test {
";
assert_eq!(expected_module, result_str);
}
#[test]
fn test_compiler_round_trip_invalid_mlir() {
let module_to_compile = "bla bla bla";
let result_str = round_trip(module_to_compile);
assert!(matches!(result_str, Err(CompilationError(_))));
}
#[test]
fn test_compiler_compile_lib() {
unsafe {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let runtime_library_path = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => val + "/lib/libConcretelangRuntime.so",
Err(_e) => "".to_string(),
};
let temp_dir = TempDir::new("rust_test_compiler_compile_lib").unwrap();
let support = LibrarySupport::new(
temp_dir.path().to_str().unwrap(),
runtime_library_path.as_str(),
);
let lib = support.compile(module_to_compile, None).unwrap();
assert!(!libraryCompilationResultIsNull(lib));
libraryCompilationResultDestroy(lib);
// the sharedlib should be enough as a sign that the compilation worked
assert!(temp_dir.path().join("sharedlib.so").exists());
}
}
/// We want to make sure setting a pointer to null in rust passes the nullptr check in C/Cpp
#[test]
fn test_compiler_null_ptr_compatibility() {
unsafe {
let lib = Library {
ptr: std::ptr::null_mut(),
};
assert!(libraryIsNull(lib));
}
}
}

View File

@@ -6,9 +6,58 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/CAPI/Wrappers.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/Support/SourceMgr.h"
/// CompilerEngine CAPI
/// ********** CompilationOptions CAPI *****************************************
CompilationOptions
compilationOptionsCreate(MlirStringRef funcName, bool autoParallelize,
bool batchConcreteOps, bool dataflowParallelize,
bool emitGPUOps, bool loopParallelize,
bool optimizeConcrete, OptimizerConfig optimizerConfig,
bool verifyDiagnostics) {
std::string funcNameStr(funcName.data, funcName.length);
auto options = new mlir::concretelang::CompilationOptions(funcNameStr);
options->autoParallelize = autoParallelize;
options->batchConcreteOps = batchConcreteOps;
options->dataflowParallelize = dataflowParallelize;
options->emitGPUOps = emitGPUOps;
options->loopParallelize = loopParallelize;
options->optimizeConcrete = optimizeConcrete;
options->optimizerConfig = *unwrap(optimizerConfig);
options->verifyDiagnostics = verifyDiagnostics;
return wrap(options);
}
CompilationOptions compilationOptionsCreateDefault() {
return wrap(new mlir::concretelang::CompilationOptions("main"));
}
/// ********** OptimizerConfig CAPI ********************************************
OptimizerConfig optimizerConfigCreate(bool display,
double fallback_log_norm_woppbs,
double global_p_error, double p_error,
uint64_t security, bool strategy_v0,
bool use_gpu_constraints) {
auto config = new mlir::concretelang::optimizer::Config();
config->display = display;
config->fallback_log_norm_woppbs = fallback_log_norm_woppbs;
config->global_p_error = global_p_error;
config->p_error = p_error;
config->security = security;
config->strategy_v0 = strategy_v0;
config->use_gpu_constraints = use_gpu_constraints;
return wrap(config);
}
OptimizerConfig optimizerConfigCreateDefault() {
return wrap(new mlir::concretelang::optimizer::Config());
}
/// ********** CompilerEngine CAPI *********************************************
CompilerEngine compilerEngineCreate() {
auto *engine = new mlir::concretelang::CompilerEngine(
@@ -18,30 +67,61 @@ CompilerEngine compilerEngineCreate() {
void compilerEngineDestroy(CompilerEngine engine) { delete unwrap(engine); }
/// Map C compilationTarget to Cpp
llvm::Expected<mlir::concretelang::CompilerEngine::Target>
targetConvertToCppFromC(CompilationTarget target) {
switch (target) {
case ROUND_TRIP:
return mlir::concretelang::CompilerEngine::Target::ROUND_TRIP;
case FHE:
return mlir::concretelang::CompilerEngine::Target::FHE;
case TFHE:
return mlir::concretelang::CompilerEngine::Target::TFHE;
case CONCRETE:
return mlir::concretelang::CompilerEngine::Target::CONCRETE;
case CONCRETEWITHLOOPS:
return mlir::concretelang::CompilerEngine::Target::CONCRETEWITHLOOPS;
case BCONCRETE:
return mlir::concretelang::CompilerEngine::Target::BCONCRETE;
case STD:
return mlir::concretelang::CompilerEngine::Target::STD;
case LLVM:
return mlir::concretelang::CompilerEngine::Target::LLVM;
case LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::LLVM_IR;
case OPTIMIZED_LLVM_IR:
return mlir::concretelang::CompilerEngine::Target::OPTIMIZED_LLVM_IR;
case LIBRARY:
return mlir::concretelang::CompilerEngine::Target::LIBRARY;
}
return mlir::concretelang::StreamStringError("invalid compilation target");
}
CompilationResult compilerEngineCompile(CompilerEngine engine,
MlirStringRef module,
CompilationTarget target) {
std::string module_str(module.data, module.length);
if (target == ROUND_TRIP) {
auto retOrError = unwrap(engine)->compile(
module_str, mlir::concretelang::CompilerEngine::Target::ROUND_TRIP);
if (!retOrError) {
// TODO: access the MlirContext
// mlir::emitError(mlir::UnknownLoc::get(unwrap(engine)) << "azeza";
return wrap(
(mlir::concretelang::CompilerEngine::CompilationResult *)nullptr);
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
auto targetCppOrError = targetConvertToCppFromC(target);
if (!targetCppOrError) { // invalid target
llvm::errs() << llvm::toString(targetCppOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
}
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)nullptr);
auto retOrError = unwrap(engine)->compile(module_str, targetCppOrError.get());
if (!retOrError) { // compilation error
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::CompilerEngine::CompilationResult *)NULL);
}
return wrap(new mlir::concretelang::CompilerEngine::CompilationResult(
std::move(retOrError.get())));
}
/// CompilationResult CAPI
void compilationResultDestroy(CompilationResult result) {
delete unwrap(result);
void compilerEngineCompileSetOptions(CompilerEngine engine,
CompilationOptions options) {
unwrap(engine)->setCompilationOptions(*unwrap(options));
}
/// ********** CompilationResult CAPI ******************************************
MlirStringRef compilationResultGetModuleString(CompilationResult result) {
// print the module into a string
std::string moduleString;
@@ -56,3 +136,55 @@ MlirStringRef compilationResultGetModuleString(CompilationResult result) {
void compilationResultDestroyModuleString(MlirStringRef str) {
delete str.data;
}
void compilationResultDestroy(CompilationResult result) {
delete unwrap(result);
}
/// ********** Library CAPI ****************************************************
Library libraryCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool cleanUp) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::CompilerEngine::Library(
outputDirPathStr, runtimeLibraryPathStr, cleanUp));
}
void libraryDestroy(Library lib) { delete unwrap(lib); }
/// ********** LibraryCompilationResult CAPI ***********************************
void libraryCompilationResultDestroy(LibraryCompilationResult result) {
delete unwrap(result);
}
/// ********** LibrarySupport CAPI *********************************************
LibrarySupport
librarySupportCreate(MlirStringRef outputDirPath,
MlirStringRef runtimeLibraryPath, bool generateSharedLib,
bool generateStaticLib, bool generateClientParameters,
bool generateCompilationFeedback, bool generateCppHeader) {
std::string outputDirPathStr(outputDirPath.data, outputDirPath.length);
std::string runtimeLibraryPathStr(runtimeLibraryPath.data,
runtimeLibraryPath.length);
return wrap(new mlir::concretelang::LibrarySupport(
outputDirPathStr, runtimeLibraryPathStr, generateSharedLib,
generateStaticLib, generateClientParameters, generateCompilationFeedback,
generateCppHeader));
}
LibraryCompilationResult librarySupportCompile(LibrarySupport support,
MlirStringRef module,
CompilationOptions options) {
std::string moduleStr(module.data, module.length);
auto retOrError = unwrap(support)->compile(moduleStr, *unwrap(options));
if (!retOrError) {
llvm::errs() << llvm::toString(retOrError.takeError());
return wrap((mlir::concretelang::LibraryCompilationResult *)NULL);
}
return wrap(new mlir::concretelang::LibraryCompilationResult(
*retOrError.get().get()));
}