feat(rust): add compiler module with round_trip feature

This commit is contained in:
youben11
2022-11-21 15:57:53 +01:00
committed by Ayoub Benaissa
parent 5661b758d7
commit 96c958bd06
4 changed files with 183 additions and 9 deletions

View File

@@ -5,6 +5,7 @@
#include <concretelang-c/Dialect/FHE.h>
#include <concretelang-c/Dialect/FHELinalg.h>
#include <concretelang-c/Support/CompilerEngine.h>
#include <mlir-c/AffineExpr.h>
#include <mlir-c/AffineMap.h>
#include <mlir-c/BuiltinAttributes.h>

View File

@@ -187,6 +187,92 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRArithmeticToLLVM",
];
const LLVM_STATIC_LIBS: [&str; 51] = [
"LLVMAggressiveInstCombine",
"LLVMAnalysis",
"LLVMAsmParser",
"LLVMAsmPrinter",
"LLVMBinaryFormat",
"LLVMBitReader",
"LLVMBitstreamReader",
"LLVMBitWriter",
"LLVMCFGuard",
"LLVMCodeGen",
"LLVMCore",
"LLVMCoroutines",
"LLVMDebugInfoCodeView",
"LLVMDebugInfoDWARF",
"LLVMDebugInfoMSF",
"LLVMDebugInfoPDB",
"LLVMDemangle",
"LLVMExecutionEngine",
"LLVMFrontendOpenMP",
"LLVMGlobalISel",
"LLVMInstCombine",
"LLVMInstrumentation",
"LLVMipo",
"LLVMIRReader",
"LLVMJITLink",
"LLVMLinker",
"LLVMMC",
"LLVMMCDisassembler",
"LLVMMCParser",
"LLVMObjCARCOpts",
"LLVMObject",
"LLVMOrcJIT",
"LLVMOrcShared",
"LLVMOrcTargetProcess",
"LLVMPasses",
"LLVMProfileData",
"LLVMRemarks",
"LLVMRuntimeDyld",
"LLVMScalarOpts",
"LLVMSelectionDAG",
"LLVMSupport",
"LLVMSymbolize",
"LLVMTableGen",
"LLVMTableGenGlobalISel",
"LLVMTarget",
"LLVMTextAPI",
"LLVMTransformUtils",
"LLVMVectorize",
"LLVMX86CodeGen",
"LLVMX86Desc",
"LLVMX86Info",
];
const CONCRETE_COMPILER_LIBS: [&str; 29] = [
"RTDialect",
"RTDialectTransforms",
"ConcretelangSupport",
"BConcreteToCAPI",
"ConcretelangConversion",
"ConcretelangTransforms",
"FHETensorOpsToLinalg",
"ConcretelangServerLib",
"ConcreteToBConcrete",
"CONCRETELANGCAPIFHE",
"TFHEGlobalParametrization",
"ConcretelangClientLib",
"ConcretelangBConcreteTransforms",
"CONCRETELANGCAPISupport",
"FHELinalgDialect",
"ConcretelangInterfaces",
"TFHEDialect",
"CONCRETELANGCAPIFHELINALG",
"FHELinalgDialectTransforms",
"FHEDialect",
"TFHEToConcrete",
"FHEToTFHE",
"ConcreteDialectTransforms",
"BConcreteDialect",
"concrete_optimizer",
"LinalgExtras",
"FHEDialectAnalysis",
"ConcreteDialect",
"RTDialectAnalysis",
];
fn main() {
if let Err(error) = run() {
eprintln!("{}", error);
@@ -195,13 +281,21 @@ fn main() {
}
fn run() -> Result<(), Box<dyn Error>> {
let root = std::fs::canonicalize("../../../../")?;
// library paths
let build_dir = get_build_dir();
let lib_dir = build_dir.join("lib");
// compiler build libs
println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap());
// concrete optimizer lib
println!(
"cargo:rustc-link-search={}",
root.join("compiler/concrete-optimizer/target/release")
.to_str()
.unwrap()
);
// include paths
let root = std::fs::canonicalize("../../../../")?;
let include_paths = [
// compiler build
build_dir.join("tools/concretelang/include/"),
@@ -220,21 +314,30 @@ fn run() -> Result<(), Box<dyn Error>> {
];
// linking
// concrete-compiler libs
for concrete_compiler_lib in CONCRETE_COMPILER_LIBS {
println!("cargo:rustc-link-lib=static={}", concrete_compiler_lib);
}
// concrete compiler runtime
println!("cargo:rustc-link-lib=ConcretelangRuntime");
// concrete optimizer
// `-bundle` serve to not have multiple definition issues
println!("cargo:rustc-link-lib=static:-bundle=concrete_optimizer_cpp");
// mlir libs
for mlir_static_lib in MLIR_STATIC_LIBS {
println!("cargo:rustc-link-lib=static={}", mlir_static_lib);
}
println!("cargo:rustc-link-lib=static=LLVMSupport");
println!("cargo:rustc-link-lib=static=LLVMCore");
// llvm libs
for llvm_static_lib in LLVM_STATIC_LIBS {
println!("cargo:rustc-link-lib=static={}", llvm_static_lib);
}
// required by llvm
println!("cargo:rustc-link-lib=tinfo");
if let Some(name) = get_system_libcpp() {
println!("cargo:rustc-link-lib={}", name);
}
// concrete-compiler libs
println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE");
println!("cargo:rustc-link-lib=static=FHEDialect");
println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHELINALG");
println!("cargo:rustc-link-lib=static=FHELinalgDialect");
// zlib
println!("cargo:rustc-link-lib=z");
println!("cargo:rerun-if-changed=api.h");
bindgen::builder()
@@ -272,7 +375,9 @@ fn get_build_dir() -> PathBuf {
.join("..")
.join("..")
.join("..")
.join("build"),
.join("build")
.canonicalize()
.unwrap(),
};
return build_dir;
}

View File

@@ -0,0 +1,67 @@
//! Compiler module
use crate::mlir::ffi::*;
/// Parse the MLIR code and returns it.
///
/// The function parse the provided MLIR textual representation and returns it. It would fail with
/// an error message to stderr reporting what's bad with the parsed IR.
///
/// # Examples
/// ```
/// use concrete_compiler_rust::compiler::*;
///
/// let module_to_compile = "
/// func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
/// %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
/// return %0 : !FHE.eint<5>
/// }";
/// let result_str = round_trip(module_to_compile);
/// ```
///
pub fn round_trip(mlir_code: &str) -> String {
unsafe {
let engine = compilerEngineCreate();
let mlir_code_buffer = mlir_code.as_bytes();
let compilation_result = compilerEngineCompile(
engine,
MlirStringRef {
data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char,
length: mlir_code_buffer.len() as size_t,
},
CompilationTarget_ROUND_TRIP,
);
let module_compiled = compilationResultGetModuleString(compilation_result);
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
module_compiled.data as *const u8,
module_compiled.length as usize,
))
.to_string();
compilationResultDestroyModuleString(module_compiled);
compilerEngineDestroy(engine);
result_str
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_compiler_round_trip() {
let module_to_compile = "
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}";
let result_str = round_trip(module_to_compile);
let expected_module = "module {
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}
}
";
assert_eq!(expected_module, result_str);
}
}

View File

@@ -1,3 +1,4 @@
pub mod compiler;
pub mod fhe;
pub mod fhelinalg;
pub mod mlir;