diff --git a/.gitignore b/.gitignore index e595d0aa4..e7fa8f484 100644 --- a/.gitignore +++ b/.gitignore @@ -51,3 +51,5 @@ _build/ compiler/tests/TestLib/out/ +compiler/lib/Bindings/Rust/target/ +compiler/lib/Bindings/Rust/Cargo.lock diff --git a/compiler/Makefile b/compiler/Makefile index 08b9234fb..51f025e3d 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -144,6 +144,13 @@ python-bindings: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules cmake --build $(BUILD_DIR) --target ConcretelangPythonModules +# MLIRCAPIRegistration is currently required for linking rust bindings +# This may fade if we can represent it somewhere else, mainly, we may be able to define a single +# lib that the rust bindings need to link to, while that lib contains all necessary libs +rust-bindings: build-initialized concretecompiler + cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration + cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release + clientlib: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangClientLib @@ -198,6 +205,10 @@ run-python-tests: python-bindings concretecompiler test-compiler-file-output: concretecompiler pytest -vs tests/test_compiler_file_output +## rust-tests +run-rust-tests: rust-bindings + cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo test --release + ## end-to-end-tests build-end-to-end-tests: build-end-to-end-jit-test build-end-to-end-jit-fhe build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg build-end-to-end-jit-lambda diff --git a/compiler/lib/Bindings/Rust/Cargo.toml b/compiler/lib/Bindings/Rust/Cargo.toml new file mode 100644 index 000000000..4ced9e254 --- /dev/null +++ b/compiler/lib/Bindings/Rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "concrete_compiler_rust" +version = "0.1.0" +edition = "2021" + +[build-dependencies] +bindgen = "0.60.1" + +[dev-dependencies] +stdio-override = "0.1.3" +tempfile = "3.3.0" diff --git a/compiler/lib/Bindings/Rust/README.md b/compiler/lib/Bindings/Rust/README.md new file mode 100644 index 000000000..f3aaa58f2 --- /dev/null +++ b/compiler/lib/Bindings/Rust/README.md @@ -0,0 +1,9 @@ +# Rust Bindings + +A Rust library providing an API to the Concrete Compiler. + +### Build + +```bash +$ cargo build --release +``` diff --git a/compiler/lib/Bindings/Rust/api.h b/compiler/lib/Bindings/Rust/api.h new file mode 100644 index 000000000..a4d0bd5a9 --- /dev/null +++ b/compiler/lib/Bindings/Rust/api.h @@ -0,0 +1,32 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs new file mode 100644 index 000000000..fb8948385 --- /dev/null +++ b/compiler/lib/Bindings/Rust/build.rs @@ -0,0 +1,273 @@ +extern crate bindgen; + +use std::env; +use std::error::Error; +use std::path::{Path, PathBuf}; +use std::process::exit; + +const MLIR_STATIC_LIBS: [&str; 179] = [ + "MLIRMemRefDialect", + "MLIRVectorToSPIRV", + "MLIRControlFlowInterfaces", + "MLIRLinalgToStandard", + "MLIRAnalysis", + "MLIRSPIRVDeserialization", + "MLIRTransformDialect", + "MLIRSparseTensorPipelines", + "MLIRVectorToGPU", + "MLIRTranslateLib", + "MLIRPass", + "MLIRComplexToLibm", + "MLIRInferTypeOpInterface", + "MLIRMemRefToSPIRV", + "MLIRAMDGPUToROCDL", + "MLIRBufferizationTransformOps", + "MLIRExecutionEngineUtils", + "MLIRNVVMDialect", + "MLIRSCFUtils", + "MLIRLinalgTransforms", + "MLIRParser", + "MLIRFuncTransforms", + "MLIRTosaTestPasses", + "MLIRTosaToArith", + "MLIRTensorDialect", + "MLIRGPUTransforms", + "MLIRLowerableDialectsToLLVM", + "MLIRBufferizationToMemRef", + "MLIRPresburger", + "MLIRFuncDialect", + "MLIRPDLToPDLInterp", + "MLIRArithmeticTransforms", + "MLIRViewLikeInterface", + "MLIRTargetCpp", + "MLIROpenMPToLLVM", + "MLIRSPIRVConversion", + "MLIRNVGPUTransforms", + "MLIRSparseTensorTransforms", + "MLIRAffineAnalysis", + "MLIRArmSVETransforms", + "MLIRArmNeon2dToIntr", + "MLIRDataLayoutInterfaces", + "MLIRAffineTransforms", + "MLIROpenACCToLLVMIRTranslation", + "MLIRTensorUtils", + "MLIRSPIRVSerialization", + "MLIRShapeToStandard", + "MLIRArithmeticToSPIRV", + "MLIRArithmeticDialect", + "MLIRFuncToSPIRV", + "MLIRQuantUtils", + "MLIRTensorTilingInterfaceImpl", + "MLIRX86VectorToLLVMIRTranslation", + "MLIRCopyOpInterface", + "MLIRMathToLibm", + "MLIRGPUToGPURuntimeTransforms", + "MLIRLLVMDialect", + "MLIRAffineDialect", + "MLIRTransforms", + "MLIRVectorTransforms", + "MLIROpenMPDialect", + "MLIRControlFlowDialect", + "MLIRVectorUtils", + "MLIRROCDLDialect", + "MLIRPDLDialect", + "MLIRAsyncDialect", + "MLIRLinalgToLLVM", + "MLIROpenACCDialect", + "MLIRVectorDialect", + "MLIROpenACCToSCF", + "MLIRIR", + "MLIRCAPIIR", + "MLIRTargetLLVMIRImport", + "MLIRTensorToLinalg", + "MLIRCallInterfaces", + "MLIRTensorInferTypeOpInterfaceImpl", + "MLIRTransformDialectTransforms", + "MLIRComplexDialect", + "MLIRAffineUtils", + "MLIRLoopLikeInterface", + "MLIRDialect", + "MLIRLinalgUtils", + "MLIRSCFToSPIRV", + "MLIRAffineToStandard", + "MLIRX86VectorDialect", + "MLIRGPUToVulkanTransforms", + "MLIRRewrite", + "MLIRAMXToLLVMIRTranslation", + "MLIRInferIntRangeInterface", + "MLIRCAPIRegistration", + "MLIRNVVMToLLVMIRTranslation", + "MLIRAsyncTransforms", + "MLIRPDLInterpDialect", + "MLIRTransformUtils", + "MLIRLinalgDialect", + "MLIRMathDialect", + "MLIRMemRefTransforms", + "MLIRSPIRVModuleCombiner", + "MLIRMathToLLVM", + "MLIRControlFlowToLLVM", + "MLIRArmSVEDialect", + "MLIRSPIRVTranslateRegistration", + "MLIRToLLVMIRTranslationRegistration", + "MLIRSCFDialect", + "MLIRTilingInterface", + "MLIREmitCDialect", + "MLIRTableGen", + "MLIRTosaToSCF", + "MLIROpenMPToLLVMIRTranslation", + "MLIRSupport", + "MLIROpenACCToLLVM", + "MLIRAMDGPUDialect", + "MLIRTosaToLinalg", + "MLIRSparseTensorUtils", + "MLIRFuncToLLVM", + "MLIRTargetLLVMIRExport", + "MLIRControlFlowToSPIRV", + "MLIRReconcileUnrealizedCasts", + "MLIRComplexToStandard", + "MLIRMathTransforms", + "MLIRSPIRVUtils", + "MLIRCastInterfaces", + "MLIRTosaToTensor", + "MLIRMemRefUtils", + "MLIRGPUToSPIRV", + "MLIRBufferizationDialect", + "MLIRSCFToControlFlow", + "MLIRArmSVEToLLVMIRTranslation", + "MLIRExecutionEngine", + "MLIRBufferizationTransforms", + "MLIRSparseTensorDialect", + "MLIRTensorToSPIRV", + "MLIRVectorToSCF", + "MLIRQuantTransforms", + "MLIRLLVMToLLVMIRTranslation", + "MLIRNVGPUDialect", + "MLIRAsyncToLLVM", + "MLIRAMXDialect", + "MLIRLinalgTransformOps", + "MLIRMathToSPIRV", + "MLIRSCFToOpenMP", + "MLIRShapeDialect", + "MLIRGPUToROCDLTransforms", + "MLIRGPUToNVVMTransforms", + "MLIRTensorTransforms", + "MLIRSCFToGPU", + "MLIRDialectUtils", + "MLIRNVGPUToNVVM", + "MLIRTosaDialect", + "MLIRVectorToLLVM", + "MLIRSPIRVDialect", + "MLIRSideEffectInterfaces", + "MLIRVectorToROCDL", + "MLIRQuantDialect", + "MLIRSCFTransforms", + "MLIRMLProgramDialect", + "MLIRLinalgToSPIRV", + "MLIRDLTIDialect", + "MLIRLinalgFrontend", + "MLIRROCDLToLLVMIRTranslation", + "MLIRArmNeonDialect", + "MLIRSPIRVToLLVM", + "MLIRLLVMIRTransforms", + "MLIRTosaTransforms", + "MLIRLLVMCommonConversion", + "MLIRSCFTransformOps", + "MLIRArmNeonToLLVMIRTranslation", + "MLIRAMXTransforms", + "MLIRSPIRVTransforms", + "MLIRMemRefToLLVM", + "MLIRSPIRVBinaryUtils", + "MLIRLinalgAnalysis", + "MLIRArithmeticUtils", + "MLIRVectorInterfaces", + "MLIRGPUOps", + "MLIRComplexToLLVM", + "MLIRShapeOpsTransforms", + "MLIRX86VectorTransforms", + "MLIRArithmeticToLLVM", +]; + +fn main() { + if let Err(error) = run() { + eprintln!("{}", error); + exit(1); + } +} + +fn run() -> Result<(), Box> { + // library paths + let build_dir = get_build_dir(); + let lib_dir = build_dir.join("lib"); + println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap()); + + // include paths + let root = std::fs::canonicalize("../../../../")?; + let include_paths = [ + // compiler build + build_dir.join("tools/concretelang/include/"), + // mlir build + build_dir.join("tools/mlir/include"), + // llvm build + build_dir.join("include"), + // compiler + root.join("compiler/include/"), + // mlir + root.join("llvm-project/mlir/include/"), + // llvm + root.join("llvm-project/llvm/include/"), + // concrete-optimizer + root.join("compiler/concrete-optimizer/concrete-optimizer-cpp/src/cpp/"), + ]; + + // linking + for mlir_static_lib in MLIR_STATIC_LIBS { + println!("cargo:rustc-link-lib=static={}", mlir_static_lib); + } + println!("cargo:rustc-link-lib=static=LLVMSupport"); + println!("cargo:rustc-link-lib=static=LLVMCore"); + // required by llvm + println!("cargo:rustc-link-lib=tinfo"); + if let Some(name) = get_system_libcpp() { + println!("cargo:rustc-link-lib={}", name); + } + + println!("cargo:rerun-if-changed=api.h"); + bindgen::builder() + .header("api.h") + .clang_args( + include_paths + .into_iter() + .map(|path| format!("-I{}", path.to_str().unwrap())), + ) + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + .generate() + .unwrap() + .write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?; + + Ok(()) +} + +fn get_system_libcpp() -> Option<&'static str> { + if cfg!(target_env = "msvc") { + None + } else if cfg!(target_os = "macos") { + Some("c++") + } else { + Some("stdc++") + } +} + +fn get_build_dir() -> PathBuf { + // this env variable can be used to point to a different build directory + let build_dir = match env::var("CONCRETE_COMPILER_BUILD_DIR") { + Ok(val) => std::path::Path::new(&val).to_path_buf(), + Err(_e) => std::path::Path::new(".") + .parent() + .unwrap() + .join("..") + .join("..") + .join("..") + .join("build"), + }; + return build_dir; +} diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs new file mode 100644 index 000000000..c64843e25 --- /dev/null +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -0,0 +1,172 @@ +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] + +include!(concat!(env!("OUT_DIR"), "/bindings.rs")); + +#[cfg(test)] +mod concrete_compiler_tests { + use super::*; + use std::ffi::CString; + use std::fs; + use stdio_override::StderrOverride; + use tempfile::tempdir; + + #[test] + fn test_function_type() { + unsafe { + let context = mlirContextCreate(); + let func_type = mlirFunctionTypeGet(context, 0, std::ptr::null(), 0, std::ptr::null()); + assert!(mlirTypeIsAFunction(func_type)); + mlirContextDestroy(context); + } + } + + #[test] + fn test_module_parsing() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + let module_string = " + module{ + func.func @test(%arg0: i64, %arg1: i64) -> i64 { + %1 = arith.addi %arg0, %arg1 : i64 + return %1: i64 + } + }"; + let module_cstring = CString::new(module_string).unwrap(); + let module_reference = mlirStringRefCreateFromCString(module_cstring.as_ptr()); + + let parsed_module = mlirModuleCreateParse(context, module_reference); + let parsed_func = mlirBlockGetFirstOperation(mlirModuleGetBody(parsed_module)); + + let func_type_str = CString::new("function_type").unwrap(); + // just check that we do have a function here, which should be enough to know that parsing worked well + assert!(mlirTypeIsAFunction(mlirTypeAttrGetValue( + mlirOperationGetAttributeByName( + parsed_func, + mlirStringRefCreateFromCString(func_type_str.as_ptr()), + ) + ))); + } + } + + #[test] + fn test_module_creation() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + let location = mlirLocationUnknownGet(context); + + // input/output types + let func_input_types = [ + mlirIntegerTypeGet(context, 64), + mlirIntegerTypeGet(context, 64), + ]; + let func_output_type = [mlirIntegerTypeGet(context, 64)]; + + // create the main block of the function + let func_block = mlirBlockCreate(2, func_input_types.as_ptr(), &location); + + let location = mlirLocationUnknownGet(context); + // create addi operation and append it to the block + let addi_str = CString::new("arith.addi").unwrap(); + let mut addi_op_state = + mlirOperationStateGet(mlirStringRefCreateFromCString(addi_str.as_ptr()), location); + mlirOperationStateAddOperands( + &mut addi_op_state, + 2, + [ + mlirBlockGetArgument(func_block, 0), + mlirBlockGetArgument(func_block, 1), + ] + .as_ptr(), + ); + mlirOperationStateEnableResultTypeInference(&mut addi_op_state); + let addi_op = mlirOperationCreate(&mut addi_op_state); + mlirBlockAppendOwnedOperation(func_block, addi_op); + + // create return operation and append it to the block + let ret_str = CString::new("func.return").unwrap(); + let mut ret_op_state = + mlirOperationStateGet(mlirStringRefCreateFromCString(ret_str.as_ptr()), location); + mlirOperationStateAddOperands( + &mut ret_op_state, + 1, + [mlirOperationGetResult(addi_op, 0)].as_ptr(), + ); + let ret_op = mlirOperationCreate(&mut ret_op_state); + mlirBlockAppendOwnedOperation(func_block, ret_op); + + // create region to hold the previously created block + let func_region = mlirRegionCreate(); + mlirRegionAppendOwnedBlock(func_region, func_block); + + // create function to hold the previously created region + let func_str = CString::new("func.func").unwrap(); + let mut func_op_state = + mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location); + mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr()); + // set function attributes + let func_type_str = CString::new("function_type").unwrap(); + let sym_name_str = CString::new("sym_name").unwrap(); + let func_name_str = CString::new("test").unwrap(); + let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet( + context, + 2, + func_input_types.as_ptr(), + 1, + func_output_type.as_ptr(), + )); + let sym_name_attr = mlirStringAttrGet( + context, + mlirStringRefCreateFromCString(func_name_str.as_ptr()), + ); + mlirOperationStateAddAttributes( + &mut func_op_state, + 2, + [ + // func type + mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(func_type_str.as_ptr()), + ), + func_type_attr, + ), + // func name + mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(sym_name_str.as_ptr()), + ), + sym_name_attr, + ), + ] + .as_ptr(), + ); + let func_op = mlirOperationCreate(&mut func_op_state); + + // create module to hold the previously created function + let module = mlirModuleCreateEmpty(location); + mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op); + + // dump module to stderr and capture it in a temp file + let temp_dir = tempdir().unwrap(); + let temp_file = temp_dir.path().join("concrete_compiler_test_stderr"); + let guard = StderrOverride::override_file(&temp_file).unwrap(); + mlirOperationDump(mlirModuleGetOperation(module)); + let printed_module = fs::read_to_string(&temp_file).unwrap(); + drop(guard); + + // assert that textual representation of the created module is as expected + let expected_filename = std::path::Path::new(".") + .parent() + .unwrap() + .join("tests") + .join("test_module_creation_expected.mlir"); + let expected_module = fs::read_to_string(expected_filename).unwrap(); + assert!(printed_module.eq(expected_module.as_str())); + } + } +} diff --git a/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir b/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir new file mode 100644 index 000000000..c8394bad6 --- /dev/null +++ b/compiler/lib/Bindings/Rust/tests/test_module_creation_expected.mlir @@ -0,0 +1,6 @@ +module { + func.func @test(%arg0: i64, %arg1: i64) -> i64 { + %0 = arith.addi %arg0, %arg1 : i64 + return %0 : i64 + } +}