feat(rust): enhance API to create func with FHE dialect

This commit is contained in:
youben11
2022-11-04 16:35:55 +01:00
committed by Ayoub Benaissa
parent eabd8b959d
commit 472e762fbf
4 changed files with 307 additions and 89 deletions

View File

@@ -147,10 +147,13 @@ python-bindings: build-initialized
# MLIRCAPIRegistration is currently required for linking rust bindings
# This may fade if we can represent it somewhere else, mainly, we may be able to define a single
# lib that the rust bindings need to link to, while that lib contains all necessary libs
rust-bindings: build-initialized concretecompiler
rust-bindings: build-initialized concretecompiler CAPI
cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration
cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release
CAPI:
cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG
clientlib: build-initialized
cmake --build $(BUILD_DIR) --target ConcretelangClientLib

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang-c/Dialect/FHE.h>
#include <mlir-c/AffineExpr.h>
#include <mlir-c/AffineMap.h>
#include <mlir-c/BuiltinAttributes.h>

View File

@@ -230,6 +230,9 @@ fn run() -> Result<(), Box<dyn Error>> {
if let Some(name) = get_system_libcpp() {
println!("cargo:rustc-link-lib={}", name);
}
// concrete-compiler libs
println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE");
println!("cargo:rustc-link-lib=static=FHEDialect");
println!("cargo:rerun-if-changed=api.h");
bindgen::builder()

View File

@@ -2,6 +2,7 @@
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
use std::ffi::CString;
use std::ops::AddAssign;
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
@@ -20,21 +21,208 @@ pub fn print_mlir_operation_to_string(op: MlirOperation) -> String {
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
unsafe {
mlirOperationPrint(
op,
mlirOperationPrint(op, Some(mlir_rust_string_receiver_callback), receiver_ptr);
}
rust_string
}
pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String {
let mut rust_string = String::default();
let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void;
unsafe {
mlirTypePrint(
mlir_type,
Some(mlir_rust_string_receiver_callback),
receiver_ptr
receiver_ptr,
);
}
rust_string
}
/// Returns a function operation with a region that contains a block.
///
/// The function would be defined using the provided input and output types. The main block of the
/// function can be later fetched, from which we can get function arguments, and it will be where
/// we append operations.
///
/// # Examples
/// ```
/// use concrete_compiler_rust::*;
/// unsafe{
/// let context = mlirContextCreate();
/// mlirRegisterAllDialects(context);
///
/// // input/output types
/// let func_input_types = [
/// mlirIntegerTypeGet(context, 64),
/// mlirIntegerTypeGet(context, 64),
/// ];
/// let func_output_types = [mlirIntegerTypeGet(context, 64)];
///
/// let func_op = create_func_with_block(
/// context,
/// "test",
/// func_input_types.as_slice(),
/// func_output_types.as_slice(),
/// );
///
/// // we can fetch the main block of the function from the function region
/// let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
/// // we can get arguments to later be used as operands to other operations
/// let func_args = [
/// mlirBlockGetArgument(func_block, 0),
/// mlirBlockGetArgument(func_block, 1),
/// ];
/// // to add an operation to the function, we will append it to the main block
/// let addi_op = create_addi_op(context, func_args[0], func_args[1]);
/// mlirBlockAppendOwnedOperation(func_block, addi_op);
/// }
/// ```
///
pub fn create_func_with_block(
context: MlirContext,
func_name: &str,
func_input_types: &[MlirType],
func_output_types: &[MlirType],
) -> MlirOperation {
unsafe {
// create the main block of the function
let location = mlirLocationUnknownGet(context);
let func_block = mlirBlockCreate(
func_input_types.len().try_into().unwrap(),
func_input_types.as_ptr(),
&location,
);
// create region to hold the previously created block
let func_region = mlirRegionCreate();
mlirRegionAppendOwnedBlock(func_region, func_block);
// create function to hold the previously created region
let location = mlirLocationUnknownGet(context);
let func_str = CString::new("func.func").unwrap();
let mut func_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
// set function attributes
let func_type_str = CString::new("function_type").unwrap();
let sym_name_str = CString::new("sym_name").unwrap();
let func_name_str = CString::new(func_name).unwrap();
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
context,
func_input_types.len().try_into().unwrap(),
func_input_types.as_ptr(),
func_output_types.len().try_into().unwrap(),
func_output_types.as_ptr(),
));
let sym_name_attr = mlirStringAttrGet(
context,
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
);
mlirOperationStateAddAttributes(
&mut func_op_state,
2,
[
// func type
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
),
func_type_attr,
),
// func name
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
),
sym_name_attr,
),
]
.as_ptr(),
);
let func_op = mlirOperationCreate(&mut func_op_state);
func_op
}
}
/// Generic function to create an MLIR operation.
///
/// Create an MLIR operation based on its mnemonic (e.g. addi), it's operands, result types, and
/// attributes. Result types can be inferred automatically if the operation itself supports that.
pub fn create_op(
context: MlirContext,
mnemonic: &str,
operands: &[MlirValue],
results: &[MlirType],
attrs: &[MlirNamedAttribute],
auto_result_type_inference: bool,
) -> MlirOperation {
let op_mnemonic = CString::new(mnemonic).unwrap();
unsafe {
let location = mlirLocationUnknownGet(context);
let mut op_state = mlirOperationStateGet(
mlirStringRefCreateFromCString(op_mnemonic.as_ptr()),
location,
);
mlirOperationStateAddOperands(
&mut op_state,
operands.len().try_into().unwrap(),
operands.as_ptr(),
);
mlirOperationStateAddAttributes(
&mut op_state,
attrs.len().try_into().unwrap(),
attrs.as_ptr(),
);
if auto_result_type_inference {
mlirOperationStateEnableResultTypeInference(&mut op_state);
} else {
mlirOperationStateAddResults(
&mut op_state,
results.len().try_into().unwrap(),
results.as_ptr(),
);
}
mlirOperationCreate(&mut op_state)
}
}
pub fn create_fhe_add_eint_op(
context: MlirContext,
lhs: MlirValue,
rhs: MlirValue,
) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(lhs)];
// infer result type from operands
create_op(
context,
"FHE.add_eint",
&[lhs, rhs],
results.as_slice(),
&[],
false,
)
}
}
pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation {
create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true)
}
pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperation {
create_op(context, "func.return", &[ret_value], &[], &[], false)
}
#[cfg(test)]
mod concrete_compiler_tests {
use super::*;
use std::ffi::CString;
#[test]
fn test_function_type() {
@@ -80,102 +268,41 @@ mod concrete_compiler_tests {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let location = mlirLocationUnknownGet(context);
// input/output types
let func_input_types = [
mlirIntegerTypeGet(context, 64),
mlirIntegerTypeGet(context, 64),
];
let func_output_type = [mlirIntegerTypeGet(context, 64)];
let func_output_types = [mlirIntegerTypeGet(context, 64)];
// create the main block of the function
let func_block = mlirBlockCreate(2, func_input_types.as_ptr(), &location);
let location = mlirLocationUnknownGet(context);
// create addi operation and append it to the block
let addi_str = CString::new("arith.addi").unwrap();
let mut addi_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(addi_str.as_ptr()), location);
mlirOperationStateAddOperands(
&mut addi_op_state,
2,
[
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
]
.as_ptr(),
let func_op = create_func_with_block(
context,
"test",
func_input_types.as_slice(),
func_output_types.as_slice(),
);
mlirOperationStateEnableResultTypeInference(&mut addi_op_state);
let addi_op = mlirOperationCreate(&mut addi_op_state);
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
let func_args = [
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
];
// create addi operation and append it to the block
let addi_op = create_addi_op(context, func_args[0], func_args[1]);
mlirBlockAppendOwnedOperation(func_block, addi_op);
// create return operation and append it to the block
let ret_str = CString::new("func.return").unwrap();
let mut ret_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(ret_str.as_ptr()), location);
mlirOperationStateAddOperands(
&mut ret_op_state,
1,
[mlirOperationGetResult(addi_op, 0)].as_ptr(),
);
let ret_op = mlirOperationCreate(&mut ret_op_state);
// create ret operation and append it to the block
let ret_op = create_ret_op(context, mlirOperationGetResult(addi_op, 0));
mlirBlockAppendOwnedOperation(func_block, ret_op);
// create region to hold the previously created block
let func_region = mlirRegionCreate();
mlirRegionAppendOwnedBlock(func_region, func_block);
// create function to hold the previously created region
let func_str = CString::new("func.func").unwrap();
let mut func_op_state =
mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location);
mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr());
// set function attributes
let func_type_str = CString::new("function_type").unwrap();
let sym_name_str = CString::new("sym_name").unwrap();
let func_name_str = CString::new("test").unwrap();
let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet(
context,
2,
func_input_types.as_ptr(),
1,
func_output_type.as_ptr(),
));
let sym_name_attr = mlirStringAttrGet(
context,
mlirStringRefCreateFromCString(func_name_str.as_ptr()),
);
mlirOperationStateAddAttributes(
&mut func_op_state,
2,
[
// func type
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(func_type_str.as_ptr()),
),
func_type_attr,
),
// func name
mlirNamedAttributeGet(
mlirIdentifierGet(
context,
mlirStringRefCreateFromCString(sym_name_str.as_ptr()),
),
sym_name_attr,
),
]
.as_ptr(),
);
let func_op = mlirOperationCreate(&mut func_op_state);
// create module to hold the previously created function
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
let printed_module = super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let printed_module =
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let expected_module = "\
module {
func.func @test(%arg0: i64, %arg1: i64) -> i64 {
@@ -185,11 +312,95 @@ module {
}
";
assert_eq!(
printed_module,
expected_module,
printed_module, expected_module,
"left: \n{}, right: \n{}",
printed_module,
expected_module);
printed_module, expected_module
);
}
}
#[test]
fn test_invalid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0);
assert!(invalid_eint.isError);
}
}
#[test]
fn test_valid_fhe_eint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
let printed_eint = super::print_mlir_type_to_string(eint);
let expected_eint = "!FHE.eint<5>";
assert_eq!(printed_eint, expected_eint);
}
}
#[test]
fn test_fhe_func() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 5-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
// set input/output types of the FHE circuit
let func_input_types = [eint, eint];
let func_output_types = [eint];
// create the func operation
let func_op = create_func_with_block(
context,
"main",
func_input_types.as_slice(),
func_output_types.as_slice(),
);
let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
let func_args = [
mlirBlockGetArgument(func_block, 0),
mlirBlockGetArgument(func_block, 1),
];
// create an FHE add_eint op and append it to the function block
let add_eint_op = create_fhe_add_eint_op(context, func_args[0], func_args[1]);
mlirBlockAppendOwnedOperation(func_block, add_eint_op);
// create ret operation and append it to the block
let ret_op = create_ret_op(context, mlirOperationGetResult(add_eint_op, 0));
mlirBlockAppendOwnedOperation(func_block, ret_op);
// create module to hold the previously created function
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op);
let printed_module =
super::print_mlir_operation_to_string(mlirModuleGetOperation(module));
let expected_module = "\
module {
func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>
return %0 : !FHE.eint<5>
}
}
";
assert_eq!(printed_module, expected_module);
}
}
}