diff --git a/compiler/Makefile b/compiler/Makefile index 673e3992a..b91770919 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -147,10 +147,13 @@ python-bindings: build-initialized # MLIRCAPIRegistration is currently required for linking rust bindings # This may fade if we can represent it somewhere else, mainly, we may be able to define a single # lib that the rust bindings need to link to, while that lib contains all necessary libs -rust-bindings: build-initialized concretecompiler +rust-bindings: build-initialized concretecompiler CAPI cmake --build $(BUILD_DIR) --target MLIRCAPIRegistration cd lib/Bindings/Rust && CONCRETE_COMPILER_BUILD_DIR=$(abspath $(BUILD_DIR)) cargo build --release +CAPI: + cmake --build $(BUILD_DIR) --target CONCRETELANGCAPIFHE CONCRETELANGCAPIFHELINALG + clientlib: build-initialized cmake --build $(BUILD_DIR) --target ConcretelangClientLib diff --git a/compiler/lib/Bindings/Rust/api.h b/compiler/lib/Bindings/Rust/api.h index a4d0bd5a9..f6c766464 100644 --- a/compiler/lib/Bindings/Rust/api.h +++ b/compiler/lib/Bindings/Rust/api.h @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include #include diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index fb8948385..de9d5111a 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -230,6 +230,9 @@ fn run() -> Result<(), Box> { if let Some(name) = get_system_libcpp() { println!("cargo:rustc-link-lib={}", name); } + // concrete-compiler libs + println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE"); + println!("cargo:rustc-link-lib=static=FHEDialect"); println!("cargo:rerun-if-changed=api.h"); bindgen::builder() diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs index f8b2818e8..7dd438b2e 100644 --- a/compiler/lib/Bindings/Rust/src/lib.rs +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -2,6 +2,7 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] +use std::ffi::CString; use std::ops::AddAssign; include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -20,21 +21,208 @@ pub fn print_mlir_operation_to_string(op: MlirOperation) -> String { let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void; unsafe { - mlirOperationPrint( - op, + mlirOperationPrint(op, Some(mlir_rust_string_receiver_callback), receiver_ptr); + } + + rust_string +} + +pub fn print_mlir_type_to_string(mlir_type: MlirType) -> String { + let mut rust_string = String::default(); + let receiver_ptr = (&mut rust_string) as *mut String as *mut ::std::os::raw::c_void; + + unsafe { + mlirTypePrint( + mlir_type, Some(mlir_rust_string_receiver_callback), - receiver_ptr + receiver_ptr, ); } rust_string } +/// Returns a function operation with a region that contains a block. +/// +/// The function would be defined using the provided input and output types. The main block of the +/// function can be later fetched, from which we can get function arguments, and it will be where +/// we append operations. +/// +/// # Examples +/// ``` +/// use concrete_compiler_rust::*; +/// unsafe{ +/// let context = mlirContextCreate(); +/// mlirRegisterAllDialects(context); +/// +/// // input/output types +/// let func_input_types = [ +/// mlirIntegerTypeGet(context, 64), +/// mlirIntegerTypeGet(context, 64), +/// ]; +/// let func_output_types = [mlirIntegerTypeGet(context, 64)]; +/// +/// let func_op = create_func_with_block( +/// context, +/// "test", +/// func_input_types.as_slice(), +/// func_output_types.as_slice(), +/// ); +/// +/// // we can fetch the main block of the function from the function region +/// let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op)); +/// // we can get arguments to later be used as operands to other operations +/// let func_args = [ +/// mlirBlockGetArgument(func_block, 0), +/// mlirBlockGetArgument(func_block, 1), +/// ]; +/// // to add an operation to the function, we will append it to the main block +/// let addi_op = create_addi_op(context, func_args[0], func_args[1]); +/// mlirBlockAppendOwnedOperation(func_block, addi_op); +/// } +/// ``` +/// +pub fn create_func_with_block( + context: MlirContext, + func_name: &str, + func_input_types: &[MlirType], + func_output_types: &[MlirType], +) -> MlirOperation { + unsafe { + // create the main block of the function + let location = mlirLocationUnknownGet(context); + let func_block = mlirBlockCreate( + func_input_types.len().try_into().unwrap(), + func_input_types.as_ptr(), + &location, + ); + + // create region to hold the previously created block + let func_region = mlirRegionCreate(); + mlirRegionAppendOwnedBlock(func_region, func_block); + + // create function to hold the previously created region + let location = mlirLocationUnknownGet(context); + let func_str = CString::new("func.func").unwrap(); + let mut func_op_state = + mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location); + mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr()); + // set function attributes + let func_type_str = CString::new("function_type").unwrap(); + let sym_name_str = CString::new("sym_name").unwrap(); + let func_name_str = CString::new(func_name).unwrap(); + let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet( + context, + func_input_types.len().try_into().unwrap(), + func_input_types.as_ptr(), + func_output_types.len().try_into().unwrap(), + func_output_types.as_ptr(), + )); + let sym_name_attr = mlirStringAttrGet( + context, + mlirStringRefCreateFromCString(func_name_str.as_ptr()), + ); + mlirOperationStateAddAttributes( + &mut func_op_state, + 2, + [ + // func type + mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(func_type_str.as_ptr()), + ), + func_type_attr, + ), + // func name + mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(sym_name_str.as_ptr()), + ), + sym_name_attr, + ), + ] + .as_ptr(), + ); + let func_op = mlirOperationCreate(&mut func_op_state); + + func_op + } +} + +/// Generic function to create an MLIR operation. +/// +/// Create an MLIR operation based on its mnemonic (e.g. addi), it's operands, result types, and +/// attributes. Result types can be inferred automatically if the operation itself supports that. +pub fn create_op( + context: MlirContext, + mnemonic: &str, + operands: &[MlirValue], + results: &[MlirType], + attrs: &[MlirNamedAttribute], + auto_result_type_inference: bool, +) -> MlirOperation { + let op_mnemonic = CString::new(mnemonic).unwrap(); + unsafe { + let location = mlirLocationUnknownGet(context); + let mut op_state = mlirOperationStateGet( + mlirStringRefCreateFromCString(op_mnemonic.as_ptr()), + location, + ); + mlirOperationStateAddOperands( + &mut op_state, + operands.len().try_into().unwrap(), + operands.as_ptr(), + ); + mlirOperationStateAddAttributes( + &mut op_state, + attrs.len().try_into().unwrap(), + attrs.as_ptr(), + ); + if auto_result_type_inference { + mlirOperationStateEnableResultTypeInference(&mut op_state); + } else { + mlirOperationStateAddResults( + &mut op_state, + results.len().try_into().unwrap(), + results.as_ptr(), + ); + } + mlirOperationCreate(&mut op_state) + } +} + +pub fn create_fhe_add_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(lhs)]; + // infer result type from operands + create_op( + context, + "FHE.add_eint", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation { + create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true) +} + +pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperation { + create_op(context, "func.return", &[ret_value], &[], &[], false) +} #[cfg(test)] mod concrete_compiler_tests { use super::*; - use std::ffi::CString; #[test] fn test_function_type() { @@ -80,102 +268,41 @@ mod concrete_compiler_tests { unsafe { let context = mlirContextCreate(); mlirRegisterAllDialects(context); - let location = mlirLocationUnknownGet(context); // input/output types let func_input_types = [ mlirIntegerTypeGet(context, 64), mlirIntegerTypeGet(context, 64), ]; - let func_output_type = [mlirIntegerTypeGet(context, 64)]; + let func_output_types = [mlirIntegerTypeGet(context, 64)]; - // create the main block of the function - let func_block = mlirBlockCreate(2, func_input_types.as_ptr(), &location); - - let location = mlirLocationUnknownGet(context); - // create addi operation and append it to the block - let addi_str = CString::new("arith.addi").unwrap(); - let mut addi_op_state = - mlirOperationStateGet(mlirStringRefCreateFromCString(addi_str.as_ptr()), location); - mlirOperationStateAddOperands( - &mut addi_op_state, - 2, - [ - mlirBlockGetArgument(func_block, 0), - mlirBlockGetArgument(func_block, 1), - ] - .as_ptr(), + let func_op = create_func_with_block( + context, + "test", + func_input_types.as_slice(), + func_output_types.as_slice(), ); - mlirOperationStateEnableResultTypeInference(&mut addi_op_state); - let addi_op = mlirOperationCreate(&mut addi_op_state); + + let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op)); + let func_args = [ + mlirBlockGetArgument(func_block, 0), + mlirBlockGetArgument(func_block, 1), + ]; + // create addi operation and append it to the block + let addi_op = create_addi_op(context, func_args[0], func_args[1]); mlirBlockAppendOwnedOperation(func_block, addi_op); - // create return operation and append it to the block - let ret_str = CString::new("func.return").unwrap(); - let mut ret_op_state = - mlirOperationStateGet(mlirStringRefCreateFromCString(ret_str.as_ptr()), location); - mlirOperationStateAddOperands( - &mut ret_op_state, - 1, - [mlirOperationGetResult(addi_op, 0)].as_ptr(), - ); - let ret_op = mlirOperationCreate(&mut ret_op_state); + // create ret operation and append it to the block + let ret_op = create_ret_op(context, mlirOperationGetResult(addi_op, 0)); mlirBlockAppendOwnedOperation(func_block, ret_op); - // create region to hold the previously created block - let func_region = mlirRegionCreate(); - mlirRegionAppendOwnedBlock(func_region, func_block); - - // create function to hold the previously created region - let func_str = CString::new("func.func").unwrap(); - let mut func_op_state = - mlirOperationStateGet(mlirStringRefCreateFromCString(func_str.as_ptr()), location); - mlirOperationStateAddOwnedRegions(&mut func_op_state, 1, [func_region].as_ptr()); - // set function attributes - let func_type_str = CString::new("function_type").unwrap(); - let sym_name_str = CString::new("sym_name").unwrap(); - let func_name_str = CString::new("test").unwrap(); - let func_type_attr = mlirTypeAttrGet(mlirFunctionTypeGet( - context, - 2, - func_input_types.as_ptr(), - 1, - func_output_type.as_ptr(), - )); - let sym_name_attr = mlirStringAttrGet( - context, - mlirStringRefCreateFromCString(func_name_str.as_ptr()), - ); - mlirOperationStateAddAttributes( - &mut func_op_state, - 2, - [ - // func type - mlirNamedAttributeGet( - mlirIdentifierGet( - context, - mlirStringRefCreateFromCString(func_type_str.as_ptr()), - ), - func_type_attr, - ), - // func name - mlirNamedAttributeGet( - mlirIdentifierGet( - context, - mlirStringRefCreateFromCString(sym_name_str.as_ptr()), - ), - sym_name_attr, - ), - ] - .as_ptr(), - ); - let func_op = mlirOperationCreate(&mut func_op_state); - // create module to hold the previously created function + let location = mlirLocationUnknownGet(context); let module = mlirModuleCreateEmpty(location); mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op); - let printed_module = super::print_mlir_operation_to_string(mlirModuleGetOperation(module)); + let printed_module = + super::print_mlir_operation_to_string(mlirModuleGetOperation(module)); let expected_module = "\ module { func.func @test(%arg0: i64, %arg1: i64) -> i64 { @@ -185,11 +312,95 @@ module { } "; assert_eq!( - printed_module, - expected_module, + printed_module, expected_module, "left: \n{}, right: \n{}", - printed_module, - expected_module); + printed_module, expected_module + ); + } + } + + #[test] + fn test_invalid_fhe_eint_type() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + let invalid_eint = fheEncryptedIntegerTypeGetChecked(context, 0); + assert!(invalid_eint.isError); + } + } + + #[test] + fn test_valid_fhe_eint_type() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + let printed_eint = super::print_mlir_type_to_string(eint); + let expected_eint = "!FHE.eint<5>"; + assert_eq!(printed_eint, expected_eint); + } + } + + #[test] + fn test_fhe_func() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 5-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + + // set input/output types of the FHE circuit + let func_input_types = [eint, eint]; + let func_output_types = [eint]; + + // create the func operation + let func_op = create_func_with_block( + context, + "main", + func_input_types.as_slice(), + func_output_types.as_slice(), + ); + let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op)); + let func_args = [ + mlirBlockGetArgument(func_block, 0), + mlirBlockGetArgument(func_block, 1), + ]; + + // create an FHE add_eint op and append it to the function block + let add_eint_op = create_fhe_add_eint_op(context, func_args[0], func_args[1]); + mlirBlockAppendOwnedOperation(func_block, add_eint_op); + + // create ret operation and append it to the block + let ret_op = create_ret_op(context, mlirOperationGetResult(add_eint_op, 0)); + mlirBlockAppendOwnedOperation(func_block, ret_op); + + // create module to hold the previously created function + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op); + + let printed_module = + super::print_mlir_operation_to_string(mlirModuleGetOperation(module)); + let expected_module = "\ +module { + func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5> + return %0 : !FHE.eint<5> + } +} +"; + assert_eq!(printed_module, expected_module); } } }