From 0a57af37af2ec90d3af3b5dcf42550651fbb6b0c Mon Sep 17 00:00:00 2001 From: youben11 Date: Fri, 11 Nov 2022 20:10:49 +0100 Subject: [PATCH] feat(rust): add API for FHEDialect's op creation --- .../concretelang-c/Dialect/FHELinalg.h | 2 - compiler/lib/Bindings/Rust/api.h | 1 + compiler/lib/Bindings/Rust/build.rs | 2 + compiler/lib/Bindings/Rust/src/fhe.rs | 22 +- compiler/lib/Bindings/Rust/src/fhelinalg.rs | 490 ++++++++++++++++++ compiler/lib/Bindings/Rust/src/lib.rs | 1 + 6 files changed, 510 insertions(+), 8 deletions(-) create mode 100644 compiler/lib/Bindings/Rust/src/fhelinalg.rs diff --git a/compiler/include/concretelang-c/Dialect/FHELinalg.h b/compiler/include/concretelang-c/Dialect/FHELinalg.h index 1949d9744..7a6a5b519 100644 --- a/compiler/include/concretelang-c/Dialect/FHELinalg.h +++ b/compiler/include/concretelang-c/Dialect/FHELinalg.h @@ -8,8 +8,6 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/Support/LLVM.h" #ifdef __cplusplus extern "C" { diff --git a/compiler/lib/Bindings/Rust/api.h b/compiler/lib/Bindings/Rust/api.h index f6c766464..e3e2ac2ca 100644 --- a/compiler/lib/Bindings/Rust/api.h +++ b/compiler/lib/Bindings/Rust/api.h @@ -4,6 +4,7 @@ // for license information. #include +#include #include #include #include diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index de9d5111a..3a632724d 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -233,6 +233,8 @@ fn run() -> Result<(), Box> { // concrete-compiler libs println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHE"); println!("cargo:rustc-link-lib=static=FHEDialect"); + println!("cargo:rustc-link-lib=static=CONCRETELANGCAPIFHELINALG"); + println!("cargo:rustc-link-lib=static=FHELinalgDialect"); println!("cargo:rerun-if-changed=api.h"); bindgen::builder() diff --git a/compiler/lib/Bindings/Rust/src/fhe.rs b/compiler/lib/Bindings/Rust/src/fhe.rs index 1656163a1..da4bbc584 100644 --- a/compiler/lib/Bindings/Rust/src/fhe.rs +++ b/compiler/lib/Bindings/Rust/src/fhe.rs @@ -135,13 +135,13 @@ pub fn create_fhe_apply_lut_op( context: MlirContext, eint: MlirValue, lut: MlirValue, - out_type: MlirType, + result_type: MlirType, ) -> MlirOperation { create_op( context, "FHE.apply_lookup_table", &[eint, lut], - [out_type].as_slice(), + [result_type].as_slice(), &[], false, ) @@ -177,16 +177,26 @@ pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> Mlir } } -pub fn create_fhe_zero_eint_op(context: MlirContext, out_type: MlirType) -> MlirOperation { - create_op(context, "FHE.zero", &[], [out_type].as_slice(), &[], false) +pub fn create_fhe_zero_eint_op(context: MlirContext, result_type: MlirType) -> MlirOperation { + create_op( + context, + "FHE.zero", + &[], + [result_type].as_slice(), + &[], + false, + ) } -pub fn create_fhe_zero_eint_tensor_op(context: MlirContext, out_type: MlirType) -> MlirOperation { +pub fn create_fhe_zero_eint_tensor_op( + context: MlirContext, + result_type: MlirType, +) -> MlirOperation { create_op( context, "FHE.zero_tensor", &[], - [out_type].as_slice(), + [result_type].as_slice(), &[], false, ) diff --git a/compiler/lib/Bindings/Rust/src/fhelinalg.rs b/compiler/lib/Bindings/Rust/src/fhelinalg.rs new file mode 100644 index 000000000..6d7a2f63c --- /dev/null +++ b/compiler/lib/Bindings/Rust/src/fhelinalg.rs @@ -0,0 +1,490 @@ +//! 
FHELinalg dialect module + +use crate::mlir::*; + +pub fn create_fhelinalg_add_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.add_eint", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_add_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.add_eint_int", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_sub_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.sub_eint", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_sub_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.sub_eint_int", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_sub_int_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.sub_int_eint", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_negate_eint_op( + context: MlirContext, + eint_tensor: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(eint_tensor)]; + // infer result type from operands + create_op( + context, + "FHELinalg.neg_eint", + &[eint_tensor], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhelinalg_mul_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.mul_eint_int", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_apply_lut_op( + context: MlirContext, + eint_tensor: MlirValue, + lut: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.apply_lookup_table", + &[eint_tensor, lut], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_apply_multi_lut_op( + context: MlirContext, + eint_tensor: MlirValue, + lut: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.apply_multi_lookup_table", + &[eint_tensor, lut], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_apply_mapped_lut_op( + context: MlirContext, + eint_tensor: MlirValue, + lut: MlirValue, + map: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.apply_mapped_lookup_table", + &[eint_tensor, lut, map], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_dot_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.dot_eint_int", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_matmul_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.matmul_eint_int", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_matmul_int_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, + result_type: 
MlirType, +) -> MlirOperation { + create_op( + context, + "FHELinalg.matmul_int_eint", + &[lhs, rhs], + [result_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhelinalg_sum_op( + context: MlirContext, + eint_tensor: MlirValue, + axes: Option, + keep_dims: Option, + result_type: MlirType, +) -> MlirOperation { + let mut attrs: Vec = Vec::new(); + if axes.is_some() { + attrs.push(axes.unwrap()); + } + if keep_dims.is_some() { + attrs.push(keep_dims.unwrap()); + } + create_op( + context, + "FHELinalg.sum", + &[eint_tensor], + [result_type].as_slice(), + attrs.as_slice(), + false, + ) +} + +pub fn create_fhelinalg_concat_op( + context: MlirContext, + eint_tensor: MlirValue, + axis: Option, + result_type: MlirType, +) -> MlirOperation { + let mut attrs: Vec = Vec::new(); + if axis.is_some() { + attrs.push(axis.unwrap()); + } + create_op( + context, + "FHELinalg.concat", + &[eint_tensor], + [result_type].as_slice(), + &attrs, + false, + ) +} + +pub fn create_fhelinalg_conv2d_op( + context: MlirContext, + input: MlirValue, + weight: MlirValue, + bias: Option, + padding: Option, + strides: Option, + dilations: Option, + group: Option, + result_type: MlirType, +) -> MlirOperation { + let mut operands = Vec::new(); + operands.push(input); + operands.push(weight); + if bias.is_some() { + operands.push(bias.unwrap()); + } + let mut attrs = Vec::new(); + if padding.is_some() { + attrs.push(padding.unwrap()); + } + if strides.is_some() { + attrs.push(strides.unwrap()); + } + if dilations.is_some() { + attrs.push(dilations.unwrap()); + } + if group.is_some() { + attrs.push(group.unwrap()); + } + create_op( + context, + "FHELinalg.conv2d", + &operands, + [result_type].as_slice(), + &attrs, + false, + ) +} + +pub fn create_fhelinalg_transpose_op( + context: MlirContext, + eint_tensor: MlirValue, + axes: Option, + result_type: MlirType, +) -> MlirOperation { + let mut attrs: Vec = Vec::new(); + if axes.is_some() { + attrs.push(axes.unwrap()); + } + create_op( + context, + "FHELinalg.transpose", + &[eint_tensor], + [result_type].as_slice(), + attrs.as_slice(), + false, + ) +} + +pub fn create_fhelinalg_from_element_op(context: MlirContext, element: MlirValue) -> MlirOperation { + unsafe { + let location = mlirLocationUnknownGet(context); + let shape: [i64; 1] = [1]; + let result_type = mlirRankedTensorTypeGetChecked( + location, + 1, + shape.as_ptr(), + mlirValueGetType(element), + mlirAttributeGetNull(), + ); + create_op( + context, + "FHELinalg.from_element", + &[element], + [result_type].as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhelinalg_to_signed_op( + context: MlirContext, + eint_tensor: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(eint_tensor)]; + // infer result type from operands + create_op( + context, + "FHELinalg.to_signed", + &[eint_tensor], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhelinalg_to_unsigned_op( + context: MlirContext, + esint_tensor: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(esint_tensor)]; + // infer result type from operands + create_op( + context, + "FHELinalg.to_unsigned", + &[esint_tensor], + results.as_slice(), + &[], + false, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::fhe::*; + + #[test] + fn test_fhelinalg_func() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHELinalg dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, 
context); + + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create a 5-bit eint tensor type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + let shape: [i64; 2] = [6, 73]; + let location = mlirLocationUnknownGet(context); + let eint_tensor = mlirRankedTensorTypeGetChecked( + location, + 2, + shape.as_ptr(), + eint, + mlirAttributeGetNull(), + ); + + // set input/output types of the FHE circuit + let func_input_types = [eint_tensor, eint_tensor]; + let func_output_types = [eint_tensor]; + + // create the func operation + let func_op = create_func_with_block( + context, + "main", + func_input_types.as_slice(), + func_output_types.as_slice(), + ); + let func_block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op)); + let func_args = [ + mlirBlockGetArgument(func_block, 0), + mlirBlockGetArgument(func_block, 1), + ]; + + // create an FHE add_eint op and append it to the function block + let add_eint_op = + create_fhelinalg_add_eint_op(context, func_args[0], func_args[1], eint_tensor); + mlirBlockAppendOwnedOperation(func_block, add_eint_op); + + // create ret operation and append it to the block + let ret_op = create_ret_op(context, mlirOperationGetResult(add_eint_op, 0)); + mlirBlockAppendOwnedOperation(func_block, ret_op); + + // create module to hold the previously created function + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + mlirBlockAppendOwnedOperation(mlirModuleGetBody(module), func_op); + + let printed_module = + super::print_mlir_operation_to_string(mlirModuleGetOperation(module)); + let expected_module = "\ +module { + func.func @main(%arg0: tensor<6x73x!FHE.eint<5>>, %arg1: tensor<6x73x!FHE.eint<5>>) -> tensor<6x73x!FHE.eint<5>> { + %0 = \"FHELinalg.add_eint\"(%arg0, %arg1) : (tensor<6x73x!FHE.eint<5>>, tensor<6x73x!FHE.eint<5>>) -> tensor<6x73x!FHE.eint<5>> + return %0 : tensor<6x73x!FHE.eint<5>> + } +} +"; + assert_eq!(printed_module, expected_module); + } + } + + #[test] + fn test_zero_tensor_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHELinalg dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create a 4-bit eint tensor type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + let shape: [i64; 3] = [60, 66, 73]; + let location = mlirLocationUnknownGet(context); + let eint_tensor = mlirRankedTensorTypeGetChecked( + location, + 3, + shape.as_ptr(), + eint, + mlirAttributeGetNull(), + ); + + let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor); + let printed_op = print_mlir_operation_to_string(zero_op); + let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n"; + assert_eq!(printed_op, expected_op); + } + } +} diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs index d574cff28..f345633db 100644 --- a/compiler/lib/Bindings/Rust/src/lib.rs +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -1,2 +1,3 @@ pub mod fhe; +pub mod fhelinalg; pub mod mlir;
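
--
Usage sketch (not part of the applied diff): the snippet below shows how the new FHELinalg builders compose with the existing FHE/MLIR helpers, mirroring the test module added in this patch. It assumes it lives inside the bindings crate (hence the crate:: imports) and only calls helpers already exercised in the tests; the "main" symbol, the 6x73 shape and the 5-bit width are illustrative, and wrapping the function into a module plus verification/printing are omitted.

use crate::{fhe::*, fhelinalg::*, mlir::*};

fn build_example_circuit() {
    unsafe {
        // create a context and load both dialects, as in the tests
        let context = mlirContextCreate();
        mlirRegisterAllDialects(context);
        mlirDialectHandleLoadDialect(mlirGetDialectHandle__fhe__(), context);
        mlirDialectHandleLoadDialect(mlirGetDialectHandle__fhelinalg__(), context);

        // build a tensor<6x73x!FHE.eint<5>> type
        let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
        assert!(!eint_or_error.isError);
        let shape: [i64; 2] = [6, 73];
        let location = mlirLocationUnknownGet(context);
        let eint_tensor = mlirRankedTensorTypeGetChecked(
            location,
            2,
            shape.as_ptr(),
            eint_or_error.type_,
            mlirAttributeGetNull(),
        );

        // func.func @main(%arg0, %arg1) with a single-block body
        let func_op = create_func_with_block(
            context,
            "main",
            [eint_tensor, eint_tensor].as_slice(),
            [eint_tensor].as_slice(),
        );
        let block = mlirRegionGetFirstBlock(mlirOperationGetFirstRegion(func_op));
        let lhs = mlirBlockGetArgument(block, 0);
        let rhs = mlirBlockGetArgument(block, 1);

        // %0 = "FHELinalg.add_eint"(%arg0, %arg1): result type passed explicitly
        let add_op = create_fhelinalg_add_eint_op(context, lhs, rhs, eint_tensor);
        mlirBlockAppendOwnedOperation(block, add_op);

        // %1 = "FHELinalg.neg_eint"(%0): result type inferred from the operand
        let neg_op =
            create_fhelinalg_negate_eint_op(context, mlirOperationGetResult(add_op, 0));
        mlirBlockAppendOwnedOperation(block, neg_op);

        // return %1
        let ret_op = create_ret_op(context, mlirOperationGetResult(neg_op, 0));
        mlirBlockAppendOwnedOperation(block, ret_op);
    }
}

Chaining FHELinalg.add_eint with FHELinalg.neg_eint also illustrates the two result-type conventions in this module: most builders take an explicit result_type, while neg_eint, to_signed and to_unsigned derive theirs from the operand via mlirValueGetType.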