test(rust): complete tests of rust bindings

updated also the API to make it easier to use by:
- creating MLIR components from native rust types instead of require
  MLIR components in the API
- adding helpers around the creation of standard dialects
This commit is contained in:
youben11
2022-11-15 19:03:08 +01:00
committed by Ayoub Benaissa
parent 5b46a74b7f
commit 52d5d908bb
3 changed files with 1602 additions and 97 deletions

View File

@@ -147,9 +147,51 @@ pub fn create_fhe_apply_lut_op(
)
}
#[derive(Debug)]
pub enum FHEError {
InvalidFHEType,
InvalidWidth,
}
pub fn convert_eint_to_esint_type(
context: MlirContext,
eint_type: MlirType,
) -> Result<MlirType, FHEError> {
unsafe {
let width = fheTypeIntegerWidthGet(eint_type);
if width == 0 {
return Err(FHEError::InvalidFHEType);
}
let type_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width);
if type_or_error.isError {
Err(FHEError::InvalidWidth)
} else {
Ok(type_or_error.type_)
}
}
}
pub fn convert_esint_to_eint_type(
context: MlirContext,
esint_type: MlirType,
) -> Result<MlirType, FHEError> {
unsafe {
let width = fheTypeIntegerWidthGet(esint_type);
if width == 0 {
return Err(FHEError::InvalidFHEType);
}
let type_or_error = fheEncryptedIntegerTypeGetChecked(context, width);
if type_or_error.isError {
Err(FHEError::InvalidWidth)
} else {
Ok(type_or_error.type_)
}
}
}
pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(eint)];
let results = [convert_eint_to_esint_type(context, mlirValueGetType(eint)).unwrap()];
// infer result type from operands
create_op(
context,
@@ -164,7 +206,7 @@ pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOpe
pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation {
unsafe {
let results = [mlirValueGetType(esint)];
let results = [convert_esint_to_eint_type(context, mlirValueGetType(esint)).unwrap()];
// infer result type from operands
create_op(
context,
@@ -226,12 +268,34 @@ mod test {
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
assert!(fheTypeIsAnEncryptedIntegerType(eint));
assert!(!fheTypeIsAnEncryptedSignedIntegerType(eint));
assert_eq!(fheTypeIntegerWidthGet(eint), 5);
let printed_eint = super::print_mlir_type_to_string(eint);
let expected_eint = "!FHE.eint<5>";
assert_eq!(printed_eint, expected_eint);
}
}
#[test]
fn test_valid_fhe_esint_type() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
assert!(!esint_or_error.isError);
let esint = esint_or_error.type_;
assert!(fheTypeIsAnEncryptedSignedIntegerType(esint));
assert!(!fheTypeIsAnEncryptedIntegerType(esint));
assert_eq!(fheTypeIntegerWidthGet(esint), 5);
let printed_esint = super::print_mlir_type_to_string(esint);
let expected_esint = "!FHE.esint<5>";
assert_eq!(printed_esint, expected_esint);
}
}
#[test]
fn test_fhe_func() {
unsafe {
@@ -313,6 +377,37 @@ module {
}
}
#[test]
fn test_zero_tensor_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 4-bit eint tensor type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4);
assert!(!eint_or_error.isError);
let eint = eint_or_error.type_;
let shape: [i64; 3] = [60, 66, 73];
let location = mlirLocationUnknownGet(context);
let eint_tensor = mlirRankedTensorTypeGetChecked(
location,
3,
shape.as_ptr(),
eint,
mlirAttributeGetNull(),
);
let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor);
let printed_op = print_mlir_operation_to_string(zero_op);
let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_add_eint_op() {
unsafe {
@@ -346,4 +441,327 @@ module {
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_add_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// add eint int
let add_eint_int_op = create_fhe_add_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, add_eint_int_op);
let printed_op = print_mlir_operation_to_string(add_eint_int_op);
let expected_op =
"%1 = \"FHE.add_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// sub eint with itself
let sub_eint_op = create_fhe_sub_eint_op(context, eint_value, eint_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_op);
let printed_op = print_mlir_operation_to_string(sub_eint_op);
let expected_op =
"%1 = \"FHE.sub_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// sub eint int
let sub_eint_int_op = create_fhe_sub_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
let expected_op =
"%1 = \"FHE.sub_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_sub_int_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// sub int eint
let sub_eint_int_op = create_fhe_sub_int_eint_op(context, int_value, eint_value);
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
let expected_op =
"%1 = \"FHE.sub_int_eint\"(%c-55_i7, %0) : (i7, !FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_negate_eint_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// negate eint
let neg_eint_op = create_fhe_negate_eint_op(context, eint_value);
mlirBlockAppendOwnedOperation(main_block, neg_eint_op);
let printed_op = print_mlir_operation_to_string(neg_eint_op);
let expected_op = "%1 = \"FHE.neg_eint\"(%0) : (!FHE.eint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_mul_eint_int_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an int via a constant op
let cst_op = create_constant_int_op(context, 73, 7);
mlirBlockAppendOwnedOperation(main_block, cst_op);
let int_value = mlirOperationGetResult(cst_op, 0);
// mul eint int
let mul_eint_int_op = create_fhe_mul_eint_int_op(context, eint_value, int_value);
mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op);
let printed_op = print_mlir_operation_to_string(mul_eint_int_op);
let expected_op =
"%1 = \"FHE.mul_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_to_signed_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// to signed
let to_signed_op = create_fhe_to_signed_op(context, eint_value);
mlirBlockAppendOwnedOperation(main_block, to_signed_op);
let printed_op = print_mlir_operation_to_string(to_signed_op);
let expected_op = "%1 = \"FHE.to_signed\"(%0) : (!FHE.eint<6>) -> !FHE.esint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_to_unsigned_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit esint type
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 6);
assert!(!esint_or_error.isError);
let esint6_type = esint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, esint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let esint_value = mlirOperationGetResult(zero_op, 0);
// to unsigned
let to_unsigned_op = create_fhe_to_unsigned_op(context, esint_value);
mlirBlockAppendOwnedOperation(main_block, to_unsigned_op);
let printed_op = print_mlir_operation_to_string(to_unsigned_op);
let expected_op = "%1 = \"FHE.to_unsigned\"(%0) : (!FHE.esint<6>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_apply_lut_op() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// register the FHE dialect
let fhe_handle = mlirGetDialectHandle__fhe__();
mlirDialectHandleLoadDialect(fhe_handle, context);
// create a 6-bit eint type
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
assert!(!eint_or_error.isError);
let eint6_type = eint_or_error.type_;
// create module for ops
let location = mlirLocationUnknownGet(context);
let module = mlirModuleCreateEmpty(location);
let main_block = mlirModuleGetBody(module);
// create an encrypted integer via a zero_op
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
mlirBlockAppendOwnedOperation(main_block, zero_op);
let eint_value = mlirOperationGetResult(zero_op, 0);
// create an lut
let table: [i64; 64] = [0; 64];
let constant_lut_op = create_constant_flat_tensor_op(context, &table, 64);
mlirBlockAppendOwnedOperation(main_block, constant_lut_op);
let lut = mlirOperationGetResult(constant_lut_op, 0);
// LUT op
let apply_lut_op = create_fhe_apply_lut_op(context, eint_value, lut, eint6_type);
mlirBlockAppendOwnedOperation(main_block, apply_lut_op);
let printed_op = print_mlir_operation_to_string(apply_lut_op);
let expected_op = "%1 = \"FHE.apply_lookup_table\"(%0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>";
assert_eq!(printed_op, expected_op);
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -203,6 +203,71 @@ pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperatio
create_op(context, "func.return", &[ret_value], &[], &[], false)
}
pub fn create_constant_int_op(context: MlirContext, cst_value: i64, width: u32) -> MlirOperation {
unsafe {
let result_type = mlirIntegerTypeGet(context, width);
let value_str = CString::new("value").unwrap();
let value_attr = mlirNamedAttributeGet(
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
mlirIntegerAttrGet(result_type, cst_value),
);
create_op(
context,
"arith.constant",
&[],
&[result_type],
&[value_attr],
true,
)
}
}
pub fn create_constant_flat_tensor_op(
context: MlirContext,
cst_table: &[i64],
bitwidth: u32,
) -> MlirOperation {
let shape = [cst_table.len().try_into().unwrap()];
create_constant_tensor_op(context, &shape, cst_table, bitwidth)
}
pub fn create_constant_tensor_op(
context: MlirContext,
shape: &[i64],
cst_table: &[i64],
bitwidth: u32,
) -> MlirOperation {
unsafe {
let result_type = mlirRankedTensorTypeGet(
shape.len().try_into().unwrap(),
shape.as_ptr(),
mlirIntegerTypeGet(context, bitwidth),
mlirAttributeGetNull(),
);
let cst_table_attrs: Vec<MlirAttribute> = cst_table
.into_iter()
.map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, bitwidth), *value))
.collect();
let value_str = CString::new("value").unwrap();
let value_attr = mlirNamedAttributeGet(
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
mlirDenseElementsAttrGet(
result_type,
cst_table.len().try_into().unwrap(),
cst_table_attrs.as_ptr(),
),
);
create_op(
context,
"arith.constant",
&[],
&[result_type],
&[value_attr],
true,
)
}
}
#[cfg(test)]
mod test {
use super::*;
@@ -301,4 +366,64 @@ module {
);
}
}
#[test]
fn test_constant_flat_tensor() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// create a constant flat tensor
let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
let printed_op = print_mlir_operation_to_string(contant_flat_tensor_op);
let expected_op = "%cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_tensor() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// create a constant tensor
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
let expected_op = "%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_tensor_with_signle_elem() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// create a constant tensor
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
let expected_op = "%cst = arith.constant dense<0> : tensor<2x2xi7>\n";
assert_eq!(printed_op, expected_op);
}
}
#[test]
fn test_constant_int() {
unsafe {
let context = mlirContextCreate();
mlirRegisterAllDialects(context);
// create a constant flat tensor
let contant_int_op = create_constant_int_op(context, 73, 10);
let printed_op = print_mlir_operation_to_string(contant_int_op);
let expected_op = "%c73_i10 = arith.constant 73 : i10\n";
assert_eq!(printed_op, expected_op);
}
}
}