From 96c958bd064cf6f572bd681d863b7ed40a634db1 Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 21 Nov 2022 15:57:53 +0100 Subject: [PATCH] feat(rust): add compiler module with round_trip feature --- compiler/lib/Bindings/Rust/api.h | 1 + compiler/lib/Bindings/Rust/build.rs | 123 +++++++++++++++++++-- compiler/lib/Bindings/Rust/src/compiler.rs | 67 +++++++++++ compiler/lib/Bindings/Rust/src/lib.rs | 1 + 4 files changed, 183 insertions(+), 9 deletions(-) create mode 100644 compiler/lib/Bindings/Rust/src/compiler.rs diff --git a/compiler/lib/Bindings/Rust/api.h b/compiler/lib/Bindings/Rust/api.h index e3e2ac2ca..542ff5bfb 100644 --- a/compiler/lib/Bindings/Rust/api.h +++ b/compiler/lib/Bindings/Rust/api.h @@ -5,6 +5,7 @@ #include #include +#include #include #include #include diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index 3a632724d..2337e89fa 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -187,6 +187,92 @@ const MLIR_STATIC_LIBS: [&str; 179] = [ "MLIRArithmeticToLLVM", ]; +const LLVM_STATIC_LIBS: [&str; 51] = [ + "LLVMAggressiveInstCombine", + "LLVMAnalysis", + "LLVMAsmParser", + "LLVMAsmPrinter", + "LLVMBinaryFormat", + "LLVMBitReader", + "LLVMBitstreamReader", + "LLVMBitWriter", + "LLVMCFGuard", + "LLVMCodeGen", + "LLVMCore", + "LLVMCoroutines", + "LLVMDebugInfoCodeView", + "LLVMDebugInfoDWARF", + "LLVMDebugInfoMSF", + "LLVMDebugInfoPDB", + "LLVMDemangle", + "LLVMExecutionEngine", + "LLVMFrontendOpenMP", + "LLVMGlobalISel", + "LLVMInstCombine", + "LLVMInstrumentation", + "LLVMipo", + "LLVMIRReader", + "LLVMJITLink", + "LLVMLinker", + "LLVMMC", + "LLVMMCDisassembler", + "LLVMMCParser", + "LLVMObjCARCOpts", + "LLVMObject", + "LLVMOrcJIT", + "LLVMOrcShared", + "LLVMOrcTargetProcess", + "LLVMPasses", + "LLVMProfileData", + "LLVMRemarks", + "LLVMRuntimeDyld", + "LLVMScalarOpts", + "LLVMSelectionDAG", + "LLVMSupport", + "LLVMSymbolize", + "LLVMTableGen", + "LLVMTableGenGlobalISel", + "LLVMTarget", + "LLVMTextAPI", + "LLVMTransformUtils", + "LLVMVectorize", + "LLVMX86CodeGen", + "LLVMX86Desc", + "LLVMX86Info", +]; + +const CONCRETE_COMPILER_LIBS: [&str; 29] = [ + "RTDialect", + "RTDialectTransforms", + "ConcretelangSupport", + "BConcreteToCAPI", + "ConcretelangConversion", + "ConcretelangTransforms", + "FHETensorOpsToLinalg", + "ConcretelangServerLib", + "ConcreteToBConcrete", + "CONCRETELANGCAPIFHE", + "TFHEGlobalParametrization", + "ConcretelangClientLib", + "ConcretelangBConcreteTransforms", + "CONCRETELANGCAPISupport", + "FHELinalgDialect", + "ConcretelangInterfaces", + "TFHEDialect", + "CONCRETELANGCAPIFHELINALG", + "FHELinalgDialectTransforms", + "FHEDialect", + "TFHEToConcrete", + "FHEToTFHE", + "ConcreteDialectTransforms", + "BConcreteDialect", + "concrete_optimizer", + "LinalgExtras", + "FHEDialectAnalysis", + "ConcreteDialect", + "RTDialectAnalysis", +]; + fn main() { if let Err(error) = run() { eprintln!("{}", error); @@ -195,13 +281,21 @@ fn main() { } fn run() -> Result<(), Box> { + let root = std::fs::canonicalize("../../../../")?; // library paths let build_dir = get_build_dir(); let lib_dir = build_dir.join("lib"); + // compiler build libs println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap()); + // concrete optimizer lib + println!( + "cargo:rustc-link-search={}", + root.join("compiler/concrete-optimizer/target/release") + .to_str() + .unwrap() + ); // include paths - let root = std::fs::canonicalize("../../../../")?; let include_paths = [ // compiler build build_dir.join("tools/concretelang/include/"), @@ -220,21 +314,30 @@ fn run() -> Result<(), Box> { ]; // linking + // concrete-compiler libs + for concrete_compiler_lib in CONCRETE_COMPILER_LIBS { + println!("cargo:rustc-link-lib=static={}", concrete_compiler_lib); + } + // concrete compiler runtime + println!("cargo:rustc-link-lib=ConcretelangRuntime"); + // concrete optimizer + // `-bundle` serve to not have multiple definition issues + println!("cargo:rustc-link-lib=static:-bundle=concrete_optimizer_cpp"); + // mlir libs for mlir_static_lib in MLIR_STATIC_LIBS { println!("cargo:rustc-link-lib=static={}", mlir_static_lib); } - println!("cargo:rustc-link-lib=static=LLVMSupport"); - println!("cargo:rustc-link-lib=static=LLVMCore"); + // llvm libs + for llvm_static_lib in LLVM_STATIC_LIBS { + println!("cargo:rustc-link-lib=static={}", llvm_static_lib); + } // required by llvm println!("cargo:rustc-link-lib=tinfo"); if let Some(name) = get_system_libcpp() { println!("cargo:rustc-link-lib={}", name); } - // concrete-compiler libs - println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE"); - println!("cargo:rustc-link-lib=static=FHEDialect"); - println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHELINALG"); - println!("cargo:rustc-link-lib=static=FHELinalgDialect"); + // zlib + println!("cargo:rustc-link-lib=z"); println!("cargo:rerun-if-changed=api.h"); bindgen::builder() @@ -272,7 +375,9 @@ fn get_build_dir() -> PathBuf { .join("..") .join("..") .join("..") - .join("build"), + .join("build") + .canonicalize() + .unwrap(), }; return build_dir; } diff --git a/compiler/lib/Bindings/Rust/src/compiler.rs b/compiler/lib/Bindings/Rust/src/compiler.rs new file mode 100644 index 000000000..21d8cf715 --- /dev/null +++ b/compiler/lib/Bindings/Rust/src/compiler.rs @@ -0,0 +1,67 @@ +//! Compiler module + +use crate::mlir::ffi::*; + +/// Parse the MLIR code and returns it. +/// +/// The function parse the provided MLIR textual representation and returns it. It would fail with +/// an error message to stderr reporting what's bad with the parsed IR. +/// +/// # Examples +/// ``` +/// use concrete_compiler_rust::compiler::*; +/// +/// let module_to_compile = " +/// func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { +/// %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> +/// return %0 : !FHE.eint<5> +/// }"; +/// let result_str = round_trip(module_to_compile); +/// ``` +/// +pub fn round_trip(mlir_code: &str) -> String { + unsafe { + let engine = compilerEngineCreate(); + let mlir_code_buffer = mlir_code.as_bytes(); + let compilation_result = compilerEngineCompile( + engine, + MlirStringRef { + data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char, + length: mlir_code_buffer.len() as size_t, + }, + CompilationTarget_ROUND_TRIP, + ); + let module_compiled = compilationResultGetModuleString(compilation_result); + let result_str = String::from_utf8_lossy(std::slice::from_raw_parts( + module_compiled.data as *const u8, + module_compiled.length as usize, + )) + .to_string(); + compilationResultDestroyModuleString(module_compiled); + compilerEngineDestroy(engine); + result_str + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_compiler_round_trip() { + let module_to_compile = " + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + }"; + let result_str = round_trip(module_to_compile); + let expected_module = "module { + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + } +} +"; + assert_eq!(expected_module, result_str); + } +} diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs index f345633db..f1844d431 100644 --- a/compiler/lib/Bindings/Rust/src/lib.rs +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -1,3 +1,4 @@ +pub mod compiler; pub mod fhe; pub mod fhelinalg; pub mod mlir;