From 52d5d908bb9b7e57b4bc72f2df0fe4d8d21e65cd Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 15 Nov 2022 19:03:08 +0100 Subject: [PATCH] test(rust): complete tests of rust bindings updated also the API to make it easier to use by: - creating MLIR components from native rust types instead of require MLIR components in the API - adding helpers around the creation of standard dialects --- compiler/lib/Bindings/Rust/src/fhe.rs | 422 ++++++- compiler/lib/Bindings/Rust/src/fhelinalg.rs | 1152 +++++++++++++++++-- compiler/lib/Bindings/Rust/src/mlir.rs | 125 ++ 3 files changed, 1602 insertions(+), 97 deletions(-) diff --git a/compiler/lib/Bindings/Rust/src/fhe.rs b/compiler/lib/Bindings/Rust/src/fhe.rs index da4bbc584..4a35a474a 100644 --- a/compiler/lib/Bindings/Rust/src/fhe.rs +++ b/compiler/lib/Bindings/Rust/src/fhe.rs @@ -147,9 +147,51 @@ pub fn create_fhe_apply_lut_op( ) } +#[derive(Debug)] +pub enum FHEError { + InvalidFHEType, + InvalidWidth, +} + +pub fn convert_eint_to_esint_type( + context: MlirContext, + eint_type: MlirType, +) -> Result { + unsafe { + let width = fheTypeIntegerWidthGet(eint_type); + if width == 0 { + return Err(FHEError::InvalidFHEType); + } + let type_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width); + if type_or_error.isError { + Err(FHEError::InvalidWidth) + } else { + Ok(type_or_error.type_) + } + } +} + +pub fn convert_esint_to_eint_type( + context: MlirContext, + esint_type: MlirType, +) -> Result { + unsafe { + let width = fheTypeIntegerWidthGet(esint_type); + if width == 0 { + return Err(FHEError::InvalidFHEType); + } + let type_or_error = fheEncryptedIntegerTypeGetChecked(context, width); + if type_or_error.isError { + Err(FHEError::InvalidWidth) + } else { + Ok(type_or_error.type_) + } + } +} + pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation { unsafe { - let results = [mlirValueGetType(eint)]; + let results = [convert_eint_to_esint_type(context, mlirValueGetType(eint)).unwrap()]; // infer result type from operands create_op( context, @@ -164,7 +206,7 @@ pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOpe pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation { unsafe { - let results = [mlirValueGetType(esint)]; + let results = [convert_esint_to_eint_type(context, mlirValueGetType(esint)).unwrap()]; // infer result type from operands create_op( context, @@ -226,12 +268,34 @@ mod test { let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5); assert!(!eint_or_error.isError); let eint = eint_or_error.type_; + assert!(fheTypeIsAnEncryptedIntegerType(eint)); + assert!(!fheTypeIsAnEncryptedSignedIntegerType(eint)); + assert_eq!(fheTypeIntegerWidthGet(eint), 5); let printed_eint = super::print_mlir_type_to_string(eint); let expected_eint = "!FHE.eint<5>"; assert_eq!(printed_eint, expected_eint); } } + #[test] + fn test_valid_fhe_esint_type() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5); + assert!(!esint_or_error.isError); + let esint = esint_or_error.type_; + assert!(fheTypeIsAnEncryptedSignedIntegerType(esint)); + assert!(!fheTypeIsAnEncryptedIntegerType(esint)); + assert_eq!(fheTypeIntegerWidthGet(esint), 5); + let printed_esint = super::print_mlir_type_to_string(esint); + let expected_esint = "!FHE.esint<5>"; + assert_eq!(printed_esint, expected_esint); + } + } + #[test] fn test_fhe_func() { unsafe { @@ -313,6 +377,37 @@ module { } } + #[test] + fn test_zero_tensor_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 4-bit eint tensor type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + let shape: [i64; 3] = [60, 66, 73]; + let location = mlirLocationUnknownGet(context); + let eint_tensor = mlirRankedTensorTypeGetChecked( + location, + 3, + shape.as_ptr(), + eint, + mlirAttributeGetNull(), + ); + + let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor); + let printed_op = print_mlir_operation_to_string(zero_op); + let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n"; + assert_eq!(printed_op, expected_op); + } + } + #[test] fn test_add_eint_op() { unsafe { @@ -346,4 +441,327 @@ module { assert_eq!(printed_op, expected_op); } } + + #[test] + fn test_add_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // create an int via a constant op + let cst_op = create_constant_int_op(context, 73, 7); + mlirBlockAppendOwnedOperation(main_block, cst_op); + let int_value = mlirOperationGetResult(cst_op, 0); + // add eint int + let add_eint_int_op = create_fhe_add_eint_int_op(context, eint_value, int_value); + mlirBlockAppendOwnedOperation(main_block, add_eint_int_op); + + let printed_op = print_mlir_operation_to_string(add_eint_int_op); + let expected_op = + "%1 = \"FHE.add_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // sub eint with itself + let sub_eint_op = create_fhe_sub_eint_op(context, eint_value, eint_value); + mlirBlockAppendOwnedOperation(main_block, sub_eint_op); + + let printed_op = print_mlir_operation_to_string(sub_eint_op); + let expected_op = + "%1 = \"FHE.sub_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // create an int via a constant op + let cst_op = create_constant_int_op(context, 73, 7); + mlirBlockAppendOwnedOperation(main_block, cst_op); + let int_value = mlirOperationGetResult(cst_op, 0); + // sub eint int + let sub_eint_int_op = create_fhe_sub_eint_int_op(context, eint_value, int_value); + mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op); + + let printed_op = print_mlir_operation_to_string(sub_eint_int_op); + let expected_op = + "%1 = \"FHE.sub_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_int_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // create an int via a constant op + let cst_op = create_constant_int_op(context, 73, 7); + mlirBlockAppendOwnedOperation(main_block, cst_op); + let int_value = mlirOperationGetResult(cst_op, 0); + // sub int eint + let sub_eint_int_op = create_fhe_sub_int_eint_op(context, int_value, eint_value); + mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op); + + let printed_op = print_mlir_operation_to_string(sub_eint_int_op); + let expected_op = + "%1 = \"FHE.sub_int_eint\"(%c-55_i7, %0) : (i7, !FHE.eint<6>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_negate_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // negate eint + let neg_eint_op = create_fhe_negate_eint_op(context, eint_value); + mlirBlockAppendOwnedOperation(main_block, neg_eint_op); + + let printed_op = print_mlir_operation_to_string(neg_eint_op); + let expected_op = "%1 = \"FHE.neg_eint\"(%0) : (!FHE.eint<6>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_mul_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // create an int via a constant op + let cst_op = create_constant_int_op(context, 73, 7); + mlirBlockAppendOwnedOperation(main_block, cst_op); + let int_value = mlirOperationGetResult(cst_op, 0); + // mul eint int + let mul_eint_int_op = create_fhe_mul_eint_int_op(context, eint_value, int_value); + mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op); + + let printed_op = print_mlir_operation_to_string(mul_eint_int_op); + let expected_op = + "%1 = \"FHE.mul_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_to_signed_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // to signed + let to_signed_op = create_fhe_to_signed_op(context, eint_value); + mlirBlockAppendOwnedOperation(main_block, to_signed_op); + + let printed_op = print_mlir_operation_to_string(to_signed_op); + let expected_op = "%1 = \"FHE.to_signed\"(%0) : (!FHE.eint<6>) -> !FHE.esint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_to_unsigned_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit esint type + let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 6); + assert!(!esint_or_error.isError); + let esint6_type = esint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, esint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let esint_value = mlirOperationGetResult(zero_op, 0); + // to unsigned + let to_unsigned_op = create_fhe_to_unsigned_op(context, esint_value); + mlirBlockAppendOwnedOperation(main_block, to_unsigned_op); + + let printed_op = print_mlir_operation_to_string(to_unsigned_op); + let expected_op = "%1 = \"FHE.to_unsigned\"(%0) : (!FHE.esint<6>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_apply_lut_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // create an lut + let table: [i64; 64] = [0; 64]; + let constant_lut_op = create_constant_flat_tensor_op(context, &table, 64); + mlirBlockAppendOwnedOperation(main_block, constant_lut_op); + let lut = mlirOperationGetResult(constant_lut_op, 0); + // LUT op + let apply_lut_op = create_fhe_apply_lut_op(context, eint_value, lut, eint6_type); + mlirBlockAppendOwnedOperation(main_block, apply_lut_op); + + let printed_op = print_mlir_operation_to_string(apply_lut_op); + let expected_op = "%1 = \"FHE.apply_lookup_table\"(%0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } } diff --git a/compiler/lib/Bindings/Rust/src/fhelinalg.rs b/compiler/lib/Bindings/Rust/src/fhelinalg.rs index 6d7a2f63c..63471e82f 100644 --- a/compiler/lib/Bindings/Rust/src/fhelinalg.rs +++ b/compiler/lib/Bindings/Rust/src/fhelinalg.rs @@ -1,6 +1,10 @@ //! FHELinalg dialect module -use crate::mlir::*; +use crate::{ + fhe::{convert_eint_to_esint_type, convert_esint_to_eint_type}, + mlir::*, +}; +use std::ffi::CString; pub fn create_fhelinalg_add_eint_op( context: MlirContext, @@ -216,45 +220,81 @@ pub fn create_fhelinalg_matmul_int_eint_op( pub fn create_fhelinalg_sum_op( context: MlirContext, eint_tensor: MlirValue, - axes: Option, - keep_dims: Option, + axes: Option<&[i64]>, + keep_dims: Option, result_type: MlirType, ) -> MlirOperation { - let mut attrs: Vec = Vec::new(); - if axes.is_some() { - attrs.push(axes.unwrap()); + unsafe { + let mut attrs: Vec = Vec::new(); + match axes { + Some(value) => { + let axes_str = CString::new("axes").unwrap(); + let axes_attrs: Vec = value + .into_iter() + .map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), *value)) + .collect(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(axes_str.as_ptr())), + mlirArrayAttrGet( + context, + value.len().try_into().unwrap(), + axes_attrs.as_ptr(), + ), + )); + } + None => (), + } + match keep_dims { + Some(value) => { + let keep_dims_str = CString::new("keep_dims").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(keep_dims_str.as_ptr()), + ), + mlirBoolAttrGet(context, value.into()), + )); + } + None => (), + } + create_op( + context, + "FHELinalg.sum", + &[eint_tensor], + [result_type].as_slice(), + attrs.as_slice(), + false, + ) } - if keep_dims.is_some() { - attrs.push(keep_dims.unwrap()); - } - create_op( - context, - "FHELinalg.sum", - &[eint_tensor], - [result_type].as_slice(), - attrs.as_slice(), - false, - ) } pub fn create_fhelinalg_concat_op( context: MlirContext, - eint_tensor: MlirValue, - axis: Option, + eint_tensors: &[MlirValue], + axis: Option, result_type: MlirType, ) -> MlirOperation { - let mut attrs: Vec = Vec::new(); - if axis.is_some() { - attrs.push(axis.unwrap()); + unsafe { + let mut attrs: Vec = Vec::new(); + match axis { + Some(value) => { + let axis_str = CString::new("axis").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(axis_str.as_ptr())), + mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), value.into()), + )); + } + None => (), + } + create_op( + context, + "FHELinalg.concat", + eint_tensors, + [result_type].as_slice(), + &attrs, + false, + ) } - create_op( - context, - "FHELinalg.concat", - &[eint_tensor], - [result_type].as_slice(), - &attrs, - false, - ) } pub fn create_fhelinalg_conv2d_op( @@ -262,59 +302,143 @@ pub fn create_fhelinalg_conv2d_op( input: MlirValue, weight: MlirValue, bias: Option, - padding: Option, - strides: Option, - dilations: Option, - group: Option, + padding: Option<&[i64]>, + strides: Option<&[i64]>, + dilations: Option<&[i64]>, + group: Option, result_type: MlirType, ) -> MlirOperation { - let mut operands = Vec::new(); - operands.push(input); - operands.push(weight); - if bias.is_some() { - operands.push(bias.unwrap()); + unsafe { + let mut operands = Vec::new(); + operands.push(input); + operands.push(weight); + match bias { + Some(value) => operands.push(value), + None => (), + } + let mut attrs = Vec::new(); + match padding { + Some(value) => { + let padding_str = CString::new("padding").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(padding_str.as_ptr()), + ), + mlirDenseElementsAttrInt64Get( + mlirRankedTensorTypeGet( + 1, + [value.len() as i64].as_ptr(), + mlirIntegerTypeGet(context, 64), + mlirAttributeGetNull(), + ), + value.len() as isize, + value.as_ptr(), + ), + )); + } + None => (), + } + match strides { + Some(value) => { + let strides_str = CString::new("strides").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(strides_str.as_ptr()), + ), + mlirDenseElementsAttrInt64Get( + mlirRankedTensorTypeGet( + 1, + [value.len() as i64].as_ptr(), + mlirIntegerTypeGet(context, 64), + mlirAttributeGetNull(), + ), + value.len() as isize, + value.as_ptr(), + ), + )); + } + None => (), + } + match dilations { + Some(value) => { + let dilations_str = CString::new("dilations").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet( + context, + mlirStringRefCreateFromCString(dilations_str.as_ptr()), + ), + mlirDenseElementsAttrInt64Get( + mlirRankedTensorTypeGet( + 1, + [value.len() as i64].as_ptr(), + mlirIntegerTypeGet(context, 64), + mlirAttributeGetNull(), + ), + value.len() as isize, + value.as_ptr(), + ), + )); + } + None => (), + } + match group { + Some(value) => { + let group_str = CString::new("group").unwrap(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(group_str.as_ptr())), + mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), value.into()), + )); + } + None => (), + } + create_op( + context, + "FHELinalg.conv2d", + &operands, + [result_type].as_slice(), + &attrs, + false, + ) } - let mut attrs = Vec::new(); - if padding.is_some() { - attrs.push(padding.unwrap()); - } - if strides.is_some() { - attrs.push(strides.unwrap()); - } - if dilations.is_some() { - attrs.push(dilations.unwrap()); - } - if group.is_some() { - attrs.push(group.unwrap()); - } - create_op( - context, - "FHELinalg.conv2d", - &operands, - [result_type].as_slice(), - &attrs, - false, - ) } pub fn create_fhelinalg_transpose_op( context: MlirContext, eint_tensor: MlirValue, - axes: Option, + axes: Option<&[i64]>, result_type: MlirType, ) -> MlirOperation { - let mut attrs: Vec = Vec::new(); - if axes.is_some() { - attrs.push(axes.unwrap()); + unsafe { + let mut attrs: Vec = Vec::new(); + match axes { + Some(value) => { + let axes_str = CString::new("axes").unwrap(); + let axes_attrs: Vec = value + .into_iter() + .map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, 64), *value)) + .collect(); + attrs.push(mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(axes_str.as_ptr())), + mlirArrayAttrGet( + context, + value.len().try_into().unwrap(), + axes_attrs.as_ptr(), + ), + )); + } + None => (), + } + create_op( + context, + "FHELinalg.transpose", + &[eint_tensor], + [result_type].as_slice(), + attrs.as_slice(), + false, + ) } - create_op( - context, - "FHELinalg.transpose", - &[eint_tensor], - [result_type].as_slice(), - attrs.as_slice(), - false, - ) } pub fn create_fhelinalg_from_element_op(context: MlirContext, element: MlirValue) -> MlirOperation { @@ -344,8 +468,17 @@ pub fn create_fhelinalg_to_signed_op( eint_tensor: MlirValue, ) -> MlirOperation { unsafe { - let results = [mlirValueGetType(eint_tensor)]; - // infer result type from operands + let input_type = mlirValueGetType(eint_tensor); + let rank = mlirShapedTypeGetRank(input_type); + let shape: Vec = (0i64..rank) + .map(|dim| mlirShapedTypeGetDimSize(input_type, dim.try_into().unwrap())) + .collect(); + let results = [mlirRankedTensorTypeGet( + rank.try_into().unwrap(), + shape.as_ptr(), + convert_eint_to_esint_type(context, mlirShapedTypeGetElementType(input_type)).unwrap(), + mlirAttributeGetNull(), + )]; create_op( context, "FHELinalg.to_signed", @@ -362,8 +495,17 @@ pub fn create_fhelinalg_to_unsigned_op( esint_tensor: MlirValue, ) -> MlirOperation { unsafe { - let results = [mlirValueGetType(esint_tensor)]; - // infer result type from operands + let input_type = mlirValueGetType(esint_tensor); + let rank = mlirShapedTypeGetRank(input_type); + let shape: Vec = (0i64..rank) + .map(|dim| mlirShapedTypeGetDimSize(input_type, dim.try_into().unwrap())) + .collect(); + let results = [mlirRankedTensorTypeGet( + rank.try_into().unwrap(), + shape.as_ptr(), + convert_esint_to_eint_type(context, mlirShapedTypeGetElementType(input_type)).unwrap(), + mlirAttributeGetNull(), + )]; create_op( context, "FHELinalg.to_unsigned", @@ -380,6 +522,36 @@ mod test { use super::*; use crate::fhe::*; + fn get_eint_tensor_type(context: MlirContext, shape: &[i64], width: u32) -> MlirType { + unsafe { + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, width); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + mlirRankedTensorTypeGetChecked( + mlirLocationUnknownGet(context), + shape.len().try_into().unwrap(), + shape.as_ptr(), + eint, + mlirAttributeGetNull(), + ) + } + } + + fn get_esint_tensor_type(context: MlirContext, shape: &[i64], width: u32) -> MlirType { + unsafe { + let eint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width); + assert!(!eint_or_error.isError); + let eint = eint_or_error.type_; + mlirRankedTensorTypeGetChecked( + mlirLocationUnknownGet(context), + shape.len().try_into().unwrap(), + shape.as_ptr(), + eint, + mlirAttributeGetNull(), + ) + } + } + #[test] fn test_fhelinalg_func() { unsafe { @@ -454,36 +626,826 @@ module { } #[test] - fn test_zero_tensor_op() { + fn test_add_eint_op() { unsafe { let context = mlirContextCreate(); mlirRegisterAllDialects(context); - - // register the FHELinalg dialect + // register the FHE dialect let fhe_handle = mlirGetDialectHandle__fhe__(); mlirDialectHandleLoadDialect(fhe_handle, context); - // register the FHELinalg dialect let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); mlirDialectHandleLoadDialect(fhelinalg_handle, context); - // create a 4-bit eint tensor type - let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4); - assert!(!eint_or_error.isError); - let eint = eint_or_error.type_; - let shape: [i64; 3] = [60, 66, 73]; + // create module for ops let location = mlirLocationUnknownGet(context); - let eint_tensor = mlirRankedTensorTypeGetChecked( - location, - 3, - shape.as_ptr(), - eint, - mlirAttributeGetNull(), - ); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); - let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor); - let printed_op = print_mlir_operation_to_string(zero_op); - let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n"; + // create a 4-bit eint tensor type + let eint_tensor_type = get_eint_tensor_type(context, &[5, 7], 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create add_eint op + let add_eint_op = create_fhelinalg_add_eint_op( + context, + eint_tensor_value, + eint_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, add_eint_op); + + let printed_op = print_mlir_operation_to_string(add_eint_op); + let expected_op = "%1 = \"FHELinalg.add_eint\"(%0, %0) : (tensor<5x7x!FHE.eint<4>>, tensor<5x7x!FHE.eint<4>>) -> tensor<5x7x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_add_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [73, 1]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create add_eint_int op + let add_eint_int_op = create_fhelinalg_add_eint_int_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, add_eint_int_op); + + let printed_op = print_mlir_operation_to_string(add_eint_int_op); + let expected_op = "%1 = \"FHELinalg.add_eint_int\"(%0, %cst) : (tensor<73x1x!FHE.eint<4>>, tensor<73x1xi5>) -> tensor<73x1x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let eint_tensor_type = get_eint_tensor_type(context, &[5, 7], 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create sub_eint op + let sub_eint_op = create_fhelinalg_sub_eint_op( + context, + eint_tensor_value, + eint_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, sub_eint_op); + + let printed_op = print_mlir_operation_to_string(sub_eint_op); + let expected_op = "%1 = \"FHELinalg.sub_eint\"(%0, %0) : (tensor<5x7x!FHE.eint<4>>, tensor<5x7x!FHE.eint<4>>) -> tensor<5x7x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [2, 4, 6, 9, 13, 100]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create sub_eint_int op + let sub_eint_int_op = create_fhelinalg_sub_eint_int_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op); + + let printed_op = print_mlir_operation_to_string(sub_eint_int_op); + let expected_op = "%1 = \"FHELinalg.sub_eint_int\"(%0, %cst) : (tensor<2x4x6x9x13x100x!FHE.eint<4>>, tensor<2x4x6x9x13x100xi5>) \ +-> tensor<2x4x6x9x13x100x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sub_int_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [1]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create sub_int_eint op + let sub_int_eint_op = create_fhelinalg_sub_int_eint_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, sub_int_eint_op); + + let printed_op = print_mlir_operation_to_string(sub_int_eint_op); + let expected_op = "%2 = \"FHELinalg.sub_int_eint\"(%0, %1) : (tensor<1x!FHE.eint<4>>, tensor<1xi5>) -> tensor<1x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_neg_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [16]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create neg_eint op + let neg_eint_op = create_fhelinalg_negate_eint_op(context, eint_tensor_value); + mlirBlockAppendOwnedOperation(main_block, neg_eint_op); + + let printed_op = print_mlir_operation_to_string(neg_eint_op); + let expected_op = "%1 = \"FHELinalg.neg_eint\"(%0) : (tensor<16x!FHE.eint<4>>) -> tensor<16x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_mul_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [100]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create mul_eint_int op + let mul_eint_int_op = create_fhelinalg_mul_eint_int_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op); + + let printed_op = print_mlir_operation_to_string(mul_eint_int_op); + let expected_op = "%1 = \"FHELinalg.mul_eint_int\"(%0, %cst) : (tensor<100x!FHE.eint<4>>, tensor<100xi5>) \ +-> tensor<100x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_apply_lut_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape_tensor = [4, 4, 4]; + let eint_tensor_type = get_eint_tensor_type(context, &shape_tensor, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &[16], &[0], 64); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let lut = mlirOperationGetResult(constant_int_tensor_op, 0); + // create lut op + let lut_op = + create_fhelinalg_apply_lut_op(context, eint_tensor_value, lut, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, lut_op); + + let printed_op = print_mlir_operation_to_string(lut_op); + let expected_op = "%1 = \"FHELinalg.apply_lookup_table\"(%0, %cst) : (tensor<4x4x4x!FHE.eint<4>>, tensor<16xi64>) \ +-> tensor<4x4x4x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_apply_multi_lut_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape_tensor = [4, 4, 4]; + let eint_tensor_type = get_eint_tensor_type(context, &shape_tensor, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = + create_constant_tensor_op(context, &[4, 4, 4, 16], &[0], 64); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let lut = mlirOperationGetResult(constant_int_tensor_op, 0); + // create lut op + let lut_op = create_fhelinalg_apply_multi_lut_op( + context, + eint_tensor_value, + lut, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, lut_op); + + let printed_op = print_mlir_operation_to_string(lut_op); + let expected_op = "%1 = \"FHELinalg.apply_multi_lookup_table\"(%0, %cst) : (tensor<4x4x4x!FHE.eint<4>>, tensor<4x4x4x16xi64>) \ +-> tensor<4x4x4x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_apply_mapped_lut_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape_tensor = [4, 4, 4]; + let eint_tensor_type = get_eint_tensor_type(context, &shape_tensor, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &[5, 16], &[0], 64); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let lut = mlirOperationGetResult(constant_int_tensor_op, 0); + // create map tensor + let constant_int_map_tensor_op = + create_constant_tensor_op(context, &[4, 4, 4], &[0], 64); + mlirBlockAppendOwnedOperation(main_block, constant_int_map_tensor_op); + let map = mlirOperationGetResult(constant_int_map_tensor_op, 0); + // create lut op + let lut_op = create_fhelinalg_apply_mapped_lut_op( + context, + eint_tensor_value, + lut, + map, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, lut_op); + + let printed_op = print_mlir_operation_to_string(lut_op); + let expected_op = "%3 = \"FHELinalg.apply_mapped_lookup_table\"(%0, %1, %2) : (tensor<4x4x4x!FHE.eint<4>>, tensor<5x16xi64>, tensor<4x4x4xi64>) \ +-> tensor<4x4x4x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_dot_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [100]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create dot_eint_int op + let dot_eint_int_op = create_fhelinalg_dot_eint_int_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, dot_eint_int_op); + + let printed_op = print_mlir_operation_to_string(dot_eint_int_op); + let expected_op = "%2 = \"FHELinalg.dot_eint_int\"(%0, %1) : (tensor<100x!FHE.eint<4>>, tensor<100xi5>) -> tensor<100x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_matmul_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [5, 5]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create matmul_eint_int op + let matmul_eint_int_op = create_fhelinalg_matmul_eint_int_op( + context, + eint_tensor_value, + int_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, matmul_eint_int_op); + + let printed_op = print_mlir_operation_to_string(matmul_eint_int_op); + let expected_op = "%1 = \"FHELinalg.matmul_eint_int\"(%0, %cst) : (tensor<5x5x!FHE.eint<4>>, tensor<5x5xi5>) -> tensor<5x5x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_matmul_int_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [5, 5]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create constant tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &shape, &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let int_tensor_value = mlirOperationGetResult(constant_int_tensor_op, 0); + // create matmul_int_eint op + let matmul_int_eint_op = create_fhelinalg_matmul_int_eint_op( + context, + int_tensor_value, + eint_tensor_value, + eint_tensor_type, + ); + mlirBlockAppendOwnedOperation(main_block, matmul_int_eint_op); + + let printed_op = print_mlir_operation_to_string(matmul_int_eint_op); + let expected_op = "%1 = \"FHELinalg.matmul_int_eint\"(%cst, %0) : (tensor<5x5xi5>, tensor<5x5x!FHE.eint<4>>) -> tensor<5x5x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_sum_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [5, 5]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create sum op + let sum_eint_op = create_fhelinalg_sum_op( + context, + eint_tensor_value, + Some(&[1]), + Some(false), + get_eint_tensor_type(context, &[5], 4), + ); + mlirBlockAppendOwnedOperation(main_block, sum_eint_op); + + let printed_op = print_mlir_operation_to_string(sum_eint_op); + let expected_op = "%1 = \"FHELinalg.sum\"(%0) {axes = [1], keep_dims = false} : (tensor<5x5x!FHE.eint<4>>) -> tensor<5x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_concat_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [3, 3]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create concat op + let concat_eint_op = create_fhelinalg_concat_op( + context, + &[eint_tensor_value, eint_tensor_value], + Some(0), + get_eint_tensor_type(context, &[6, 3], 4), + ); + mlirBlockAppendOwnedOperation(main_block, concat_eint_op); + + let printed_op = print_mlir_operation_to_string(concat_eint_op); + let expected_op = "%1 = \"FHELinalg.concat\"(%0, %0) {axis = 0 : i64} : (tensor<3x3x!FHE.eint<4>>, tensor<3x3x!FHE.eint<4>>) -> \ +tensor<6x3x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_conv2d_eint_int_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let eint_tensor_type = get_eint_tensor_type(context, &[100, 3, 28, 28], 4); + // create a zero tensor as input + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let input = mlirOperationGetResult(zero_tensor_op, 0); + // create constant weight tensor + let constant_int_tensor_op = + create_constant_tensor_op(context, &[4, 3, 14, 14], &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let weight = mlirOperationGetResult(constant_int_tensor_op, 0); + // create constant bias tensor + let constant_int_tensor_op = create_constant_tensor_op(context, &[4], &[0], 5); + mlirBlockAppendOwnedOperation(main_block, constant_int_tensor_op); + let bias = mlirOperationGetResult(constant_int_tensor_op, 0); + // create matmul_eint_int op + let conv2d_op = create_fhelinalg_conv2d_op( + context, + input, + weight, + Some(bias), + Some(&[0, 0, 0, 0]), + Some(&[1, 1]), + Some(&[1, 1]), + Some(1), + get_eint_tensor_type(context, &[100, 4, 15, 15], 4), + ); + mlirBlockAppendOwnedOperation(main_block, conv2d_op); + + let printed_op = print_mlir_operation_to_string(conv2d_op); + let expected_op = "%1 = \"FHELinalg.conv2d\"(%0, %cst, %cst_0) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, \ +padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<4>>, tensor<4x3x14x14xi5>, tensor<4xi5>) \ +-> tensor<100x4x15x15x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_transpose_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [2, 3, 4, 5]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create transpose op + let transpose_eint_op = create_fhelinalg_transpose_op( + context, + eint_tensor_value, + Some(&[1, 3, 0, 2]), + get_eint_tensor_type(context, &[3, 5, 2, 4], 4), + ); + mlirBlockAppendOwnedOperation(main_block, transpose_eint_op); + + let printed_op = print_mlir_operation_to_string(transpose_eint_op); + let expected_op = "%1 = \"FHELinalg.transpose\"(%0) {axes = [1, 3, 0, 2]} : (tensor<2x3x4x5x!FHE.eint<4>>) -> tensor<3x5x2x4x!FHE.eint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_from_element_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 2-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 2); + assert!(!eint_or_error.isError); + let eint2_type = eint_or_error.type_; + // create a zero eint + let zero_op = create_fhe_zero_eint_tensor_op(context, eint2_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let value = mlirOperationGetResult(zero_op, 0); + // create from element op + let from_element_op = create_fhelinalg_from_element_op(context, value); + mlirBlockAppendOwnedOperation(main_block, from_element_op); + + let printed_op = print_mlir_operation_to_string(from_element_op); + let expected_op = + "%1 = \"FHELinalg.from_element\"(%0) : (!FHE.eint<2>) -> tensor<1x!FHE.eint<2>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_to_signed_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [2, 3, 4, 5]; + let eint_tensor_type = get_eint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create to_signed op + let to_signed_op = create_fhelinalg_to_signed_op(context, eint_tensor_value); + mlirBlockAppendOwnedOperation(main_block, to_signed_op); + + let printed_op = print_mlir_operation_to_string(to_signed_op); + let expected_op = "%1 = \"FHELinalg.to_signed\"(%0) : (tensor<2x3x4x5x!FHE.eint<4>>) -> tensor<2x3x4x5x!FHE.esint<4>>"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_to_unsigned_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + // register the FHELinalg dialect + let fhelinalg_handle = mlirGetDialectHandle__fhelinalg__(); + mlirDialectHandleLoadDialect(fhelinalg_handle, context); + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + + // create a 4-bit eint tensor type + let shape = [2, 3, 4, 5]; + let eint_tensor_type = get_esint_tensor_type(context, &shape, 4); + // create a zero tensor + let zero_tensor_op = create_fhe_zero_eint_tensor_op(context, eint_tensor_type); + mlirBlockAppendOwnedOperation(main_block, zero_tensor_op); + let eint_tensor_value = mlirOperationGetResult(zero_tensor_op, 0); + // create to_unsigned op + let to_unsigned_op = create_fhelinalg_to_unsigned_op(context, eint_tensor_value); + mlirBlockAppendOwnedOperation(main_block, to_unsigned_op); + + let printed_op = print_mlir_operation_to_string(to_unsigned_op); + let expected_op = "%1 = \"FHELinalg.to_unsigned\"(%0) : (tensor<2x3x4x5x!FHE.esint<4>>) -> tensor<2x3x4x5x!FHE.eint<4>>"; assert_eq!(printed_op, expected_op); } } diff --git a/compiler/lib/Bindings/Rust/src/mlir.rs b/compiler/lib/Bindings/Rust/src/mlir.rs index 8810a60d0..bd59ec1b0 100644 --- a/compiler/lib/Bindings/Rust/src/mlir.rs +++ b/compiler/lib/Bindings/Rust/src/mlir.rs @@ -203,6 +203,71 @@ pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperatio create_op(context, "func.return", &[ret_value], &[], &[], false) } +pub fn create_constant_int_op(context: MlirContext, cst_value: i64, width: u32) -> MlirOperation { + unsafe { + let result_type = mlirIntegerTypeGet(context, width); + let value_str = CString::new("value").unwrap(); + let value_attr = mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())), + mlirIntegerAttrGet(result_type, cst_value), + ); + create_op( + context, + "arith.constant", + &[], + &[result_type], + &[value_attr], + true, + ) + } +} + +pub fn create_constant_flat_tensor_op( + context: MlirContext, + cst_table: &[i64], + bitwidth: u32, +) -> MlirOperation { + let shape = [cst_table.len().try_into().unwrap()]; + create_constant_tensor_op(context, &shape, cst_table, bitwidth) +} + +pub fn create_constant_tensor_op( + context: MlirContext, + shape: &[i64], + cst_table: &[i64], + bitwidth: u32, +) -> MlirOperation { + unsafe { + let result_type = mlirRankedTensorTypeGet( + shape.len().try_into().unwrap(), + shape.as_ptr(), + mlirIntegerTypeGet(context, bitwidth), + mlirAttributeGetNull(), + ); + let cst_table_attrs: Vec = cst_table + .into_iter() + .map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, bitwidth), *value)) + .collect(); + let value_str = CString::new("value").unwrap(); + let value_attr = mlirNamedAttributeGet( + mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())), + mlirDenseElementsAttrGet( + result_type, + cst_table.len().try_into().unwrap(), + cst_table_attrs.as_ptr(), + ), + ); + create_op( + context, + "arith.constant", + &[], + &[result_type], + &[value_attr], + true, + ) + } +} + #[cfg(test)] mod test { use super::*; @@ -301,4 +366,64 @@ module { ); } } + + #[test] + fn test_constant_flat_tensor() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // create a constant flat tensor + let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64); + + let printed_op = print_mlir_operation_to_string(contant_flat_tensor_op); + let expected_op = "%cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>\n"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_constant_tensor() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // create a constant tensor + let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64); + + let printed_op = print_mlir_operation_to_string(contant_tensor_op); + let expected_op = "%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>\n"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_constant_tensor_with_signle_elem() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // create a constant tensor + let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7); + + let printed_op = print_mlir_operation_to_string(contant_tensor_op); + let expected_op = "%cst = arith.constant dense<0> : tensor<2x2xi7>\n"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_constant_int() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // create a constant flat tensor + let contant_int_op = create_constant_int_op(context, 73, 10); + + let printed_op = print_mlir_operation_to_string(contant_int_op); + let expected_op = "%c73_i10 = arith.constant 73 : i10\n"; + assert_eq!(printed_op, expected_op); + } + } }