mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
feat(rust): add compiler module with round_trip feature
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <concretelang-c/Dialect/FHE.h>
|
||||
#include <concretelang-c/Dialect/FHELinalg.h>
|
||||
#include <concretelang-c/Support/CompilerEngine.h>
|
||||
#include <mlir-c/AffineExpr.h>
|
||||
#include <mlir-c/AffineMap.h>
|
||||
#include <mlir-c/BuiltinAttributes.h>
|
||||
|
||||
@@ -187,6 +187,92 @@ const MLIR_STATIC_LIBS: [&str; 179] = [
|
||||
"MLIRArithmeticToLLVM",
|
||||
];
|
||||
|
||||
const LLVM_STATIC_LIBS: [&str; 51] = [
|
||||
"LLVMAggressiveInstCombine",
|
||||
"LLVMAnalysis",
|
||||
"LLVMAsmParser",
|
||||
"LLVMAsmPrinter",
|
||||
"LLVMBinaryFormat",
|
||||
"LLVMBitReader",
|
||||
"LLVMBitstreamReader",
|
||||
"LLVMBitWriter",
|
||||
"LLVMCFGuard",
|
||||
"LLVMCodeGen",
|
||||
"LLVMCore",
|
||||
"LLVMCoroutines",
|
||||
"LLVMDebugInfoCodeView",
|
||||
"LLVMDebugInfoDWARF",
|
||||
"LLVMDebugInfoMSF",
|
||||
"LLVMDebugInfoPDB",
|
||||
"LLVMDemangle",
|
||||
"LLVMExecutionEngine",
|
||||
"LLVMFrontendOpenMP",
|
||||
"LLVMGlobalISel",
|
||||
"LLVMInstCombine",
|
||||
"LLVMInstrumentation",
|
||||
"LLVMipo",
|
||||
"LLVMIRReader",
|
||||
"LLVMJITLink",
|
||||
"LLVMLinker",
|
||||
"LLVMMC",
|
||||
"LLVMMCDisassembler",
|
||||
"LLVMMCParser",
|
||||
"LLVMObjCARCOpts",
|
||||
"LLVMObject",
|
||||
"LLVMOrcJIT",
|
||||
"LLVMOrcShared",
|
||||
"LLVMOrcTargetProcess",
|
||||
"LLVMPasses",
|
||||
"LLVMProfileData",
|
||||
"LLVMRemarks",
|
||||
"LLVMRuntimeDyld",
|
||||
"LLVMScalarOpts",
|
||||
"LLVMSelectionDAG",
|
||||
"LLVMSupport",
|
||||
"LLVMSymbolize",
|
||||
"LLVMTableGen",
|
||||
"LLVMTableGenGlobalISel",
|
||||
"LLVMTarget",
|
||||
"LLVMTextAPI",
|
||||
"LLVMTransformUtils",
|
||||
"LLVMVectorize",
|
||||
"LLVMX86CodeGen",
|
||||
"LLVMX86Desc",
|
||||
"LLVMX86Info",
|
||||
];
|
||||
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 29] = [
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
"BConcreteToCAPI",
|
||||
"ConcretelangConversion",
|
||||
"ConcretelangTransforms",
|
||||
"FHETensorOpsToLinalg",
|
||||
"ConcretelangServerLib",
|
||||
"ConcreteToBConcrete",
|
||||
"CONCRETELANGCAPIFHE",
|
||||
"TFHEGlobalParametrization",
|
||||
"ConcretelangClientLib",
|
||||
"ConcretelangBConcreteTransforms",
|
||||
"CONCRETELANGCAPISupport",
|
||||
"FHELinalgDialect",
|
||||
"ConcretelangInterfaces",
|
||||
"TFHEDialect",
|
||||
"CONCRETELANGCAPIFHELINALG",
|
||||
"FHELinalgDialectTransforms",
|
||||
"FHEDialect",
|
||||
"TFHEToConcrete",
|
||||
"FHEToTFHE",
|
||||
"ConcreteDialectTransforms",
|
||||
"BConcreteDialect",
|
||||
"concrete_optimizer",
|
||||
"LinalgExtras",
|
||||
"FHEDialectAnalysis",
|
||||
"ConcreteDialect",
|
||||
"RTDialectAnalysis",
|
||||
];
|
||||
|
||||
fn main() {
|
||||
if let Err(error) = run() {
|
||||
eprintln!("{}", error);
|
||||
@@ -195,13 +281,21 @@ fn main() {
|
||||
}
|
||||
|
||||
fn run() -> Result<(), Box<dyn Error>> {
|
||||
let root = std::fs::canonicalize("../../../../")?;
|
||||
// library paths
|
||||
let build_dir = get_build_dir();
|
||||
let lib_dir = build_dir.join("lib");
|
||||
// compiler build libs
|
||||
println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap());
|
||||
// concrete optimizer lib
|
||||
println!(
|
||||
"cargo:rustc-link-search={}",
|
||||
root.join("compiler/concrete-optimizer/target/release")
|
||||
.to_str()
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
// include paths
|
||||
let root = std::fs::canonicalize("../../../../")?;
|
||||
let include_paths = [
|
||||
// compiler build
|
||||
build_dir.join("tools/concretelang/include/"),
|
||||
@@ -220,21 +314,30 @@ fn run() -> Result<(), Box<dyn Error>> {
|
||||
];
|
||||
|
||||
// linking
|
||||
// concrete-compiler libs
|
||||
for concrete_compiler_lib in CONCRETE_COMPILER_LIBS {
|
||||
println!("cargo:rustc-link-lib=static={}", concrete_compiler_lib);
|
||||
}
|
||||
// concrete compiler runtime
|
||||
println!("cargo:rustc-link-lib=ConcretelangRuntime");
|
||||
// concrete optimizer
|
||||
// `-bundle` serve to not have multiple definition issues
|
||||
println!("cargo:rustc-link-lib=static:-bundle=concrete_optimizer_cpp");
|
||||
// mlir libs
|
||||
for mlir_static_lib in MLIR_STATIC_LIBS {
|
||||
println!("cargo:rustc-link-lib=static={}", mlir_static_lib);
|
||||
}
|
||||
println!("cargo:rustc-link-lib=static=LLVMSupport");
|
||||
println!("cargo:rustc-link-lib=static=LLVMCore");
|
||||
// llvm libs
|
||||
for llvm_static_lib in LLVM_STATIC_LIBS {
|
||||
println!("cargo:rustc-link-lib=static={}", llvm_static_lib);
|
||||
}
|
||||
// required by llvm
|
||||
println!("cargo:rustc-link-lib=tinfo");
|
||||
if let Some(name) = get_system_libcpp() {
|
||||
println!("cargo:rustc-link-lib={}", name);
|
||||
}
|
||||
// concrete-compiler libs
|
||||
println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE");
|
||||
println!("cargo:rustc-link-lib=static=FHEDialect");
|
||||
println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHELINALG");
|
||||
println!("cargo:rustc-link-lib=static=FHELinalgDialect");
|
||||
// zlib
|
||||
println!("cargo:rustc-link-lib=z");
|
||||
|
||||
println!("cargo:rerun-if-changed=api.h");
|
||||
bindgen::builder()
|
||||
@@ -272,7 +375,9 @@ fn get_build_dir() -> PathBuf {
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("..")
|
||||
.join("build"),
|
||||
.join("build")
|
||||
.canonicalize()
|
||||
.unwrap(),
|
||||
};
|
||||
return build_dir;
|
||||
}
|
||||
|
||||
67
compiler/lib/Bindings/Rust/src/compiler.rs
Normal file
67
compiler/lib/Bindings/Rust/src/compiler.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Compiler module
|
||||
|
||||
use crate::mlir::ffi::*;
|
||||
|
||||
/// Parse the MLIR code and returns it.
|
||||
///
|
||||
/// The function parse the provided MLIR textual representation and returns it. It would fail with
|
||||
/// an error message to stderr reporting what's bad with the parsed IR.
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use concrete_compiler_rust::compiler::*;
|
||||
///
|
||||
/// let module_to_compile = "
|
||||
/// func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
/// %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
/// return %0 : !FHE.eint<5>
|
||||
/// }";
|
||||
/// let result_str = round_trip(module_to_compile);
|
||||
/// ```
|
||||
///
|
||||
pub fn round_trip(mlir_code: &str) -> String {
|
||||
unsafe {
|
||||
let engine = compilerEngineCreate();
|
||||
let mlir_code_buffer = mlir_code.as_bytes();
|
||||
let compilation_result = compilerEngineCompile(
|
||||
engine,
|
||||
MlirStringRef {
|
||||
data: mlir_code_buffer.as_ptr() as *const std::os::raw::c_char,
|
||||
length: mlir_code_buffer.len() as size_t,
|
||||
},
|
||||
CompilationTarget_ROUND_TRIP,
|
||||
);
|
||||
let module_compiled = compilationResultGetModuleString(compilation_result);
|
||||
let result_str = String::from_utf8_lossy(std::slice::from_raw_parts(
|
||||
module_compiled.data as *const u8,
|
||||
module_compiled.length as usize,
|
||||
))
|
||||
.to_string();
|
||||
compilationResultDestroyModuleString(module_compiled);
|
||||
compilerEngineDestroy(engine);
|
||||
result_str
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_compiler_round_trip() {
|
||||
let module_to_compile = "
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %0 : !FHE.eint<5>
|
||||
}";
|
||||
let result_str = round_trip(module_to_compile);
|
||||
let expected_module = "module {
|
||||
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
|
||||
return %0 : !FHE.eint<5>
|
||||
}
|
||||
}
|
||||
";
|
||||
assert_eq!(expected_module, result_str);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod compiler;
|
||||
pub mod fhe;
|
||||
pub mod fhelinalg;
|
||||
pub mod mlir;
|
||||
|
||||
Reference in New Issue
Block a user