diff --git a/compiler/lib/Bindings/Rust/src/lib.rs b/compiler/lib/Bindings/Rust/src/lib.rs index 7dd438b2e..33fe2c93b 100644 --- a/compiler/lib/Bindings/Rust/src/lib.rs +++ b/compiler/lib/Bindings/Rust/src/lib.rs @@ -212,6 +212,177 @@ pub fn create_fhe_add_eint_op( } } +pub fn create_fhe_add_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(lhs)]; + // infer result type from operands + create_op( + context, + "FHE.add_eint_int", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_sub_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(lhs)]; + // infer result type from operands + create_op( + context, + "FHE.sub_eint", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_sub_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(lhs)]; + // infer result type from operands + create_op( + context, + "FHE.sub_eint_int", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_sub_int_eint_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(rhs)]; + // infer result type from operands + create_op( + context, + "FHE.sub_int_eint", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_negate_eint_op(context: MlirContext, eint: MlirValue) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(eint)]; + // infer result type from operands + create_op( + context, + "FHE.neg_eint", + &[eint], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_mul_eint_int_op( + context: MlirContext, + lhs: MlirValue, + rhs: MlirValue, +) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(lhs)]; + // infer result type from operands + create_op( + context, + "FHE.mul_eint_int", + &[lhs, rhs], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_apply_lut_op( + context: MlirContext, + eint: MlirValue, + lut: MlirValue, + out_type: MlirType, +) -> MlirOperation { + create_op( + context, + "FHE.apply_lookup_table", + &[eint, lut], + [out_type].as_slice(), + &[], + false, + ) +} + +pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(eint)]; + // infer result type from operands + create_op( + context, + "FHE.to_signed", + &[eint], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation { + unsafe { + let results = [mlirValueGetType(esint)]; + // infer result type from operands + create_op( + context, + "FHE.to_unsigned", + &[esint], + results.as_slice(), + &[], + false, + ) + } +} + +pub fn create_fhe_zero_eint_op(context: MlirContext, out_type: MlirType) -> MlirOperation { + create_op(context, "FHE.zero", &[], [out_type].as_slice(), &[], false) +} + +pub fn create_fhe_zero_eint_tensor_op(context: MlirContext, out_type: MlirType) -> MlirOperation { + create_op( + context, + "FHE.zero_tensor", + &[], + [out_type].as_slice(), + &[], + false, + ) +} + pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation { create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true) } @@ -403,4 +574,60 @@ module { assert_eq!(printed_module, expected_module); } } + + #[test] + fn test_zero_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + let printed_op = print_mlir_operation_to_string(zero_op); + let expected_op = "%0 = \"FHE.zero\"() : () -> !FHE.eint<6>\n"; + assert_eq!(printed_op, expected_op); + } + } + + #[test] + fn test_add_eint_op() { + unsafe { + let context = mlirContextCreate(); + mlirRegisterAllDialects(context); + + // register the FHE dialect + let fhe_handle = mlirGetDialectHandle__fhe__(); + mlirDialectHandleLoadDialect(fhe_handle, context); + + // create a 6-bit eint type + let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6); + assert!(!eint_or_error.isError); + let eint6_type = eint_or_error.type_; + + // create module for ops + let location = mlirLocationUnknownGet(context); + let module = mlirModuleCreateEmpty(location); + let main_block = mlirModuleGetBody(module); + // create an encrypted integer via a zero_op + let zero_op = create_fhe_zero_eint_op(context, eint6_type); + mlirBlockAppendOwnedOperation(main_block, zero_op); + let eint_value = mlirOperationGetResult(zero_op, 0); + // add eint with itself + let add_eint_op = create_fhe_add_eint_op(context, eint_value, eint_value); + mlirBlockAppendOwnedOperation(main_block, add_eint_op); + + let printed_op = print_mlir_operation_to_string(add_eint_op); + let expected_op = + "%1 = \"FHE.add_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>"; + assert_eq!(printed_op, expected_op); + } + } }