feat: add rust bindings

The rust bindings are intented to access both LLVM/MLIR CAPI as well as
the concrete-compiler one. This initial commit provide the API for
LLVM/MLIR only. Tests should be used as an example to how to generate a
valid DAG of operations in MLIR.
This commit is contained in:
youben11
2022-10-27 15:04:56 +01:00
committed by Ayoub Benaissa
parent 0493030033
commit 0ac21fd037
8 changed files with 516 additions and 0 deletions

2
.gitignore vendored
View File

@@ -51,3 +51,5 @@ _build/
compiler/tests/TestLib/out/
compiler/lib/Bindings/Rust/target/
compiler/lib/Bindings/Rust/Cargo.lock

View File

@@ -144,6 +144,13 @@ python-bindings: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangMLIRPythonModules
cmake --build $(BUILD_DIR) --target ConcretelangPythonModules
# MLIRCAPIRegistration is currently required for linking rust bindings
# This may fade if we can represent it somewhere else, mainly, we may be able to define a single
# lib that the rust bindings need to link to, while that lib contains all necessary libs
rust-bindings: build-initialized concretecompiler
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release
clientlib: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangClientLib
@@ -198,6 +205,10 @@ run-python-tests: python-bindings concretecompiler
test-compiler-file-output: concretecompiler
pytest -vs tests/test_compiler_file_output
## rust-tests
run-rust-tests: rust-bindings
cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo test --release
## end-to-end-tests
build-end-to-end-tests: build-end-to-end-jit-test build-end-to-end-jit-fhe build-end-to-end-jit-encrypted-tensor build-end-to-end-jit-fhelinalg build-end-to-end-jit-lambda

View File

@@ -0,0 +1,11 @@
[package]
name = "concrete_compiler_rust"
version = "0.1.0"
edition = "2021"
[build-dependencies]
bindgen = "0.60.1"
[dev-dependencies]
stdio-override = "0.1.3"
tempfile = "3.3.0"

View File

@@ -0,0 +1,9 @@
# Rust Bindings
A Rust library providing an API to the Concrete Compiler.
### Build
```bash
$ cargo build --release
```

View File

@@ -0,0 +1,32 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir-c/AffineExpr.h>
#include <mlir-c/AffineMap.h>
#include <mlir-c/BuiltinAttributes.h>
#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/Conversion.h>
#include <mlir-c/Debug.h>
#include <mlir-c/Diagnostics.h>
#include <mlir-c/Dialect/Async.h>
#include <mlir-c/Dialect/ControlFlow.h>
#include <mlir-c/Dialect/Func.h>
#include <mlir-c/Dialect/GPU.h>
#include <mlir-c/Dialect/LLVM.h>
#include <mlir-c/Dialect/Linalg.h>
#include <mlir-c/Dialect/PDL.h>
#include <mlir-c/Dialect/Quant.h>
#include <mlir-c/Dialect/SCF.h>
#include <mlir-c/Dialect/Shape.h>
#include <mlir-c/Dialect/SparseTensor.h>
#include <mlir-c/Dialect/Tensor.h>
#include <mlir-c/ExecutionEngine.h>
#include <mlir-c/IR.h>
#include <mlir-c/IntegerSet.h>
#include <mlir-c/Interfaces.h>
#include <mlir-c/Pass.h>
#include <mlir-c/Registration.h>
#include <mlir-c/Support.h>
#include <mlir-c/Transforms.h>

View File

@@ -0,0 +1,273 @@
extern crate bindgen;
use std::env;
use std::error::Error;
use std::path::{Path, PathBuf};
use std::process::exit;
const MLIR_STATIC_LIBS: [&str; 179] = [
"MLIRMemRefDialect",
"MLIRVectorToSPIRV",
"MLIRControlFlowInterfaces",
"MLIRLinalgToStandard",
"MLIRAnalysis",
"MLIRSPIRVDeserialization",
"MLIRTransformDialect",
"MLIRSparseTensorPipelines",
"MLIRVectorToGPU",
"MLIRTranslateLib",
"MLIRPass",
"MLIRComplexToLibm",
"MLIRInferTypeOpInterface",
"MLIRMemRefToSPIRV",
"MLIRAMDGPUToROCDL",
"MLIRBufferizationTransformOps",
"MLIRExecutionEngineUtils",
"MLIRNVVMDialect",
"MLIRSCFUtils",
"MLIRLinalgTransforms",
"MLIRParser",
"MLIRFuncTransforms",
"MLIRTosaTestPasses",
"MLIRTosaToArith",
"MLIRTensorDialect",
"MLIRGPUTransforms",
"MLIRLowerableDialectsToLLVM",
"MLIRBufferizationToMemRef",
"MLIRPresburger",
"MLIRFuncDialect",
"MLIRPDLToPDLInterp",
"MLIRArithmeticTransforms",
"MLIRViewLikeInterface",
"MLIRTargetCpp",
"MLIROpenMPToLLVM",
"MLIRSPIRVConversion",
"MLIRNVGPUTransforms",
"MLIRSparseTensorTransforms",
"MLIRAffineAnalysis",
"MLIRArmSVETransforms",
"MLIRArmNeon2dToIntr",
"MLIRDataLayoutInterfaces",
"MLIRAffineTransforms",
"MLIROpenACCToLLVMIRTranslation",
"MLIRTensorUtils",
"MLIRSPIRVSerialization",
"MLIRShapeToStandard",
"MLIRArithmeticToSPIRV",
"MLIRArithmeticDialect",
"MLIRFuncToSPIRV",
"MLIRQuantUtils",
"MLIRTensorTilingInterfaceImpl",
"MLIRX86VectorToLLVMIRTranslation",
"MLIRCopyOpInterface",
"MLIRMathToLibm",
"MLIRGPUToGPURuntimeTransforms",
"MLIRLLVMDialect",
"MLIRAffineDialect",
"MLIRTransforms",
"MLIRVectorTransforms",
"MLIROpenMPDialect",
"MLIRControlFlowDialect",
"MLIRVectorUtils",
"MLIRROCDLDialect",
"MLIRPDLDialect",
"MLIRAsyncDialect",
"MLIRLinalgToLLVM",
"MLIROpenACCDialect",
"MLIRVectorDialect",
"MLIROpenACCToSCF",
"MLIRIR",
"MLIRCAPIIR",
"MLIRTargetLLVMIRImport",
"MLIRTensorToLinalg",
"MLIRCallInterfaces",
"MLIRTensorInferTypeOpInterfaceImpl",
"MLIRTransformDialectTransforms",
"MLIRComplexDialect",
"MLIRAffineUtils",
"MLIRLoopLikeInterface",
"MLIRDialect",
"MLIRLinalgUtils",
"MLIRSCFToSPIRV",
"MLIRAffineToStandard",
"MLIRX86VectorDialect",
"MLIRGPUToVulkanTransforms",
"MLIRRewrite",
"MLIRAMXToLLVMIRTranslation",
"MLIRInferIntRangeInterface",
"MLIRCAPIRegistration",
"MLIRNVVMToLLVMIRTranslation",
"MLIRAsyncTransforms",
"MLIRPDLInterpDialect",
"MLIRTransformUtils",
"MLIRLinalgDialect",
"MLIRMathDialect",
"MLIRMemRefTransforms",
"MLIRSPIRVModuleCombiner",
"MLIRMathToLLVM",
"MLIRControlFlowToLLVM",
"MLIRArmSVEDialect",
"MLIRSPIRVTranslateRegistration",
"MLIRToLLVMIRTranslationRegistration",
"MLIRSCFDialect",
"MLIRTilingInterface",
"MLIREmitCDialect",
"MLIRTableGen",
"MLIRTosaToSCF",
"MLIROpenMPToLLVMIRTranslation",
"MLIRSupport",
"MLIROpenACCToLLVM",
"MLIRAMDGPUDialect",
"MLIRTosaToLinalg",
"MLIRSparseTensorUtils",
"MLIRFuncToLLVM",
"MLIRTargetLLVMIRExport",
"MLIRControlFlowToSPIRV",
"MLIRReconcileUnrealizedCasts",
"MLIRComplexToStandard",
"MLIRMathTransforms",
"MLIRSPIRVUtils",
"MLIRCastInterfaces",
"MLIRTosaToTensor",
"MLIRMemRefUtils",
"MLIRGPUToSPIRV",
"MLIRBufferizationDialect",
"MLIRSCFToControlFlow",
"MLIRArmSVEToLLVMIRTranslation",
"MLIRExecutionEngine",
"MLIRBufferizationTransforms",
"MLIRSparseTensorDialect",
"MLIRTensorToSPIRV",
"MLIRVectorToSCF",
"MLIRQuantTransforms",
"MLIRLLVMToLLVMIRTranslation",
"MLIRNVGPUDialect",
"MLIRAsyncToLLVM",
"MLIRAMXDialect",
"MLIRLinalgTransformOps",
"MLIRMathToSPIRV",
"MLIRSCFToOpenMP",
"MLIRShapeDialect",
"MLIRGPUToROCDLTransforms",
"MLIRGPUToNVVMTransforms",
"MLIRTensorTransforms",
"MLIRSCFToGPU",
"MLIRDialectUtils",
"MLIRNVGPUToNVVM",
"MLIRTosaDialect",
"MLIRVectorToLLVM",
"MLIRSPIRVDialect",
"MLIRSideEffectInterfaces",
"MLIRVectorToROCDL",
"MLIRQuantDialect",
"MLIRSCFTransforms",
"MLIRMLProgramDialect",
"MLIRLinalgToSPIRV",
"MLIRDLTIDialect",
"MLIRLinalgFrontend",
"MLIRROCDLToLLVMIRTranslation",
"MLIRArmNeonDialect",
"MLIRSPIRVToLLVM",
"MLIRLLVMIRTransforms",
"MLIRTosaTransforms",
"MLIRLLVMCommonConversion",
"MLIRSCFTransformOps",
"MLIRArmNeonToLLVMIRTranslation",
"MLIRAMXTransforms",
"MLIRSPIRVTransforms",
"MLIRMemRefToLLVM",
"MLIRSPIRVBinaryUtils",
"MLIRLinalgAnalysis",
"MLIRArithmeticUtils",
"MLIRVectorInterfaces",
"MLIRGPUOps",
"MLIRComplexToLLVM",
"MLIRShapeOpsTransforms",
"MLIRX86VectorTransforms",
"MLIRArithmeticToLLVM",
];
fn main() {
if let Err(error) = run() {
eprintln!("{}", error);
exit(1);
}
}
fn run() -> Result<(), Box<dyn Error>> {
// library paths
let build_dir = get_build_dir();
let lib_dir = build_dir.join("lib");
println!("cargo:rustc-link-search={}", lib_dir.to_str().unwrap());
// include paths
let root = std::fs::canonicalize("../../../../")?;
let include_paths = [
// compiler build
build_dir.join("tools/concretelang/include/"),
// mlir build
build_dir.join("tools/mlir/include"),
// llvm build
build_dir.join("include"),
// compiler
root.join("compiler/include/"),
// mlir
root.join("llvm-project/mlir/include/"),
// llvm
root.join("llvm-project/llvm/include/"),
// concrete-optimizer
root.join("compiler/concrete-optimizer/concrete-optimizer-cpp/src/cpp/"),
];
// linking
for mlir_static_lib in MLIR_STATIC_LIBS {
println!("cargo:rustc-link-lib=static={}", mlir_static_lib);
}
println!("cargo:rustc-link-lib=static=LLVMSupport");
println!("cargo:rustc-link-lib=static=LLVMCore");
// required by llvm
println!("cargo:rustc-link-lib=tinfo");
if let Some(name) = get_system_libcpp() {
println!("cargo:rustc-link-lib={}", name);
}
println!("cargo:rerun-if-changed=api.h");
bindgen::builder()
.header("api.h")
.clang_args(
include_paths
.into_iter()
.map(|path| format!("-I{}", path.to_str().unwrap())),
)
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.generate()
.unwrap()
.write_to_file(Path::new(&env::var("OUT_DIR")?).join("bindings.rs"))?;
Ok(())
}
fn get_system_libcpp() -> Option<&'static str> {
if cfg!(target_env = "msvc") {
None
} else if cfg!(target_os = "macos") {
Some("c++")
} else {
Some("stdc++")
}
}
fn get_build_dir() -> PathBuf {
// this env variable can be used to point to a different build directory
let build_dir = match env::var("CONCRETE_COMPILER_BUILD_DIR") {
Ok(val) => std::path::Path::new(&val).to_path_buf(),
Err(_e) => std::path::Path::new(".")
.parent()
.unwrap()
.join("..")
.join("..")
.join("..")
.join("build"),
};
return build_dir;
}

View File

@@ -0,0 +1,172 @@
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
#[cfg(test)]
mod concrete_compiler_tests {
use super::*;
use std::ffi::CString;
use std::fs;
use stdio_override::StderrOverride;
use tempfile::tempdir;
#[test]
fn test_function_type() {
unsafe {
let context = mlirContextCreate();
let func_type = mlirFunctionTypeGet(context, 0, std::ptr::null(), 0, std::ptr::null());
assert!(mlirTypeIsAFunction(func_type));
mlirContextDestroy(context);
}
}
#[test]
fn test_module_parsing() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let module_string = "
module{
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%1 = arith.addi %arg0, %arg1 : i64
return %1: i64
}
}";
let module_cstring = CString::new(module_string).unwrap();
let module_reference = mlirStringRefCreateFromCString(module_cstring.as_ptr());
let parsed_module = mlirModuleCreateParse(context, module_reference);
let parsed_func = mlirBlockGetFirstOperation(mlirModuleGetBody(parsed_module));
let func_type_str = CString::new("function_type").unwrap();
// just check that we do have a function here, which should be enough to know that parsing worked well
assert!(mlirTypeIsAFunction(mlirTypeAttrGetValue(
mlirOperationGetAttributeByName(
parsed_func,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
)
)));
}
}
#[test]
fn test_module_creation() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let location = mlirLocationUnknownGet(context);
// input/output types
let func_input_types = [
mlirIntegerTypeGet(context, 64),
mlirIntegerTypeGet(context, 64),
];
let func_output_type = [mlirIntegerTypeGet(context, 64)];
// create the main block of the function
let func_block = mlirBlockCreate(2, func_input_types.as_ptr(), &location);
let location = mlirLocationUnknownGet(context);
// create addi operation and append it to the block
let addi_str = CString::new("arith.addi").unwrap();
let mut addi_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(addi_str.as_ptr()), location);
mlirOperationStateAddOperands(
&mut addi_op_state,
2,
[
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
]
.as_ptr(),
);
mlirOperationStateEnableResultTypeInference(&mut addi_op_state);
let addi_op = mlirOperationCreate(&mut addi_op_state);
mlirBlockAppendOwnedOperation(func_block, addi_op);
// create return operation and append it to the block
let ret_str = CString::new("func.return").unwrap();
let mut ret_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(ret_str.as_ptr()), location);
mlirOperationStateAddOperands(
&mut ret_op_state,
1,
[mlirOperationGetResult(addi_op, 0)].as_ptr(),
);
let ret_op = mlirOperationCreate(&mut ret_op_state);
mlirBlockAppendOwnedOperation(func_block, ret_op);
// create region to hold the previously created block
let func_region = mlirRegionCreate();
mlirRegionAppendOwnedBlock(func_region, func_block);
// create function to hold the previously created region
let func_str = CString::new("func.func").unwrap();
let mut func_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
// set function attributes
let func_type_str = CString::new("function_type").unwrap();
let sym_name_str = CString::new("sym_name").unwrap();
let func_name_str = CString::new("test").unwrap();
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
context,
2,
func_input_types.as_ptr(),
1,
func_output_type.as_ptr(),
));
let sym_name_attr = mlirStringAttrGet(
context,
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
);
mlirOperationStateAddAttributes(
&mut func_op_state,
2,
[
// func type
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
),
func_type_attr,
),
// func name
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
),
sym_name_attr,
),
]
.as_ptr(),
);
let func_op = mlirOperationCreate(&mut func_op_state);
// create module to hold the previously created function
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
// dump module to stderr and capture it in a temp file
let temp_dir = tempdir().unwrap();
let temp_file = temp_dir.path().join("concrete_compiler_test_stderr");
let guard = StderrOverride::override_file(&temp_file).unwrap();
mlirOperationDump(mlirModuleGetOperation(module));
let printed_module = fs::read_to_string(&temp_file).unwrap();
drop(guard);
// assert that textual representation of the created module is as expected
let expected_filename = std::path::Path::new(".")
.parent()
.unwrap()
.join("tests")
.join("test_module_creation_expected.mlir");
let expected_module = fs::read_to_string(expected_filename).unwrap();
assert!(printed_module.eq(expected_module.as_str()));
}
}
}

View File

@@ -0,0 +1,6 @@
module {
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
%0 = arith.addi %arg0, %arg1 : i64
return %0 : i64
}
}