mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
test(rust): complete tests of rust bindings
updated also the API to make it easier to use by: - creating MLIR components from native rust types instead of require MLIR components in the API - adding helpers around the creation of standard dialects
This commit is contained in:
@@ -147,9 +147,51 @@ pub fn create_fhe_apply_lut_op(
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum FHEError {
|
||||
InvalidFHEType,
|
||||
InvalidWidth,
|
||||
}
|
||||
|
||||
pub fn convert_eint_to_esint_type(
|
||||
context: MlirContext,
|
||||
eint_type: MlirType,
|
||||
) -> Result<MlirType, FHEError> {
|
||||
unsafe {
|
||||
let width = fheTypeIntegerWidthGet(eint_type);
|
||||
if width == 0 {
|
||||
return Err(FHEError::InvalidFHEType);
|
||||
}
|
||||
let type_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, width);
|
||||
if type_or_error.isError {
|
||||
Err(FHEError::InvalidWidth)
|
||||
} else {
|
||||
Ok(type_or_error.type_)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn convert_esint_to_eint_type(
|
||||
context: MlirContext,
|
||||
esint_type: MlirType,
|
||||
) -> Result<MlirType, FHEError> {
|
||||
unsafe {
|
||||
let width = fheTypeIntegerWidthGet(esint_type);
|
||||
if width == 0 {
|
||||
return Err(FHEError::InvalidFHEType);
|
||||
}
|
||||
let type_or_error = fheEncryptedIntegerTypeGetChecked(context, width);
|
||||
if type_or_error.isError {
|
||||
Err(FHEError::InvalidWidth)
|
||||
} else {
|
||||
Ok(type_or_error.type_)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(eint)];
|
||||
let results = [convert_eint_to_esint_type(context, mlirValueGetType(eint)).unwrap()];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
@@ -164,7 +206,7 @@ pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOpe
|
||||
|
||||
pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(esint)];
|
||||
let results = [convert_esint_to_eint_type(context, mlirValueGetType(esint)).unwrap()];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
@@ -226,12 +268,34 @@ mod test {
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 5);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint = eint_or_error.type_;
|
||||
assert!(fheTypeIsAnEncryptedIntegerType(eint));
|
||||
assert!(!fheTypeIsAnEncryptedSignedIntegerType(eint));
|
||||
assert_eq!(fheTypeIntegerWidthGet(eint), 5);
|
||||
let printed_eint = super::print_mlir_type_to_string(eint);
|
||||
let expected_eint = "!FHE.eint<5>";
|
||||
assert_eq!(printed_eint, expected_eint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_valid_fhe_esint_type() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 5);
|
||||
assert!(!esint_or_error.isError);
|
||||
let esint = esint_or_error.type_;
|
||||
assert!(fheTypeIsAnEncryptedSignedIntegerType(esint));
|
||||
assert!(!fheTypeIsAnEncryptedIntegerType(esint));
|
||||
assert_eq!(fheTypeIntegerWidthGet(esint), 5);
|
||||
let printed_esint = super::print_mlir_type_to_string(esint);
|
||||
let expected_esint = "!FHE.esint<5>";
|
||||
assert_eq!(printed_esint, expected_esint);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fhe_func() {
|
||||
unsafe {
|
||||
@@ -313,6 +377,37 @@ module {
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_tensor_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 4-bit eint tensor type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 4);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint = eint_or_error.type_;
|
||||
let shape: [i64; 3] = [60, 66, 73];
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let eint_tensor = mlirRankedTensorTypeGetChecked(
|
||||
location,
|
||||
3,
|
||||
shape.as_ptr(),
|
||||
eint,
|
||||
mlirAttributeGetNull(),
|
||||
);
|
||||
|
||||
let zero_op = create_fhe_zero_eint_tensor_op(context, eint_tensor);
|
||||
let printed_op = print_mlir_operation_to_string(zero_op);
|
||||
let expected_op = "%0 = \"FHE.zero_tensor\"() : () -> tensor<60x66x73x!FHE.eint<4>>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_eint_op() {
|
||||
unsafe {
|
||||
@@ -346,4 +441,327 @@ module {
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// add eint int
|
||||
let add_eint_int_op = create_fhe_add_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, add_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(add_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.add_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// sub eint with itself
|
||||
let sub_eint_op = create_fhe_sub_eint_op(context, eint_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// sub eint int
|
||||
let sub_eint_int_op = create_fhe_sub_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sub_int_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// sub int eint
|
||||
let sub_eint_int_op = create_fhe_sub_int_eint_op(context, int_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, sub_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(sub_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.sub_int_eint\"(%c-55_i7, %0) : (i7, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_negate_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// negate eint
|
||||
let neg_eint_op = create_fhe_negate_eint_op(context, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, neg_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(neg_eint_op);
|
||||
let expected_op = "%1 = \"FHE.neg_eint\"(%0) : (!FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mul_eint_int_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an int via a constant op
|
||||
let cst_op = create_constant_int_op(context, 73, 7);
|
||||
mlirBlockAppendOwnedOperation(main_block, cst_op);
|
||||
let int_value = mlirOperationGetResult(cst_op, 0);
|
||||
// mul eint int
|
||||
let mul_eint_int_op = create_fhe_mul_eint_int_op(context, eint_value, int_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, mul_eint_int_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(mul_eint_int_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.mul_eint_int\"(%0, %c-55_i7) : (!FHE.eint<6>, i7) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_signed_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// to signed
|
||||
let to_signed_op = create_fhe_to_signed_op(context, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, to_signed_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(to_signed_op);
|
||||
let expected_op = "%1 = \"FHE.to_signed\"(%0) : (!FHE.eint<6>) -> !FHE.esint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_unsigned_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit esint type
|
||||
let esint_or_error = fheEncryptedSignedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!esint_or_error.isError);
|
||||
let esint6_type = esint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, esint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let esint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// to unsigned
|
||||
let to_unsigned_op = create_fhe_to_unsigned_op(context, esint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, to_unsigned_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(to_unsigned_op);
|
||||
let expected_op = "%1 = \"FHE.to_unsigned\"(%0) : (!FHE.esint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_apply_lut_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// create an lut
|
||||
let table: [i64; 64] = [0; 64];
|
||||
let constant_lut_op = create_constant_flat_tensor_op(context, &table, 64);
|
||||
mlirBlockAppendOwnedOperation(main_block, constant_lut_op);
|
||||
let lut = mlirOperationGetResult(constant_lut_op, 0);
|
||||
// LUT op
|
||||
let apply_lut_op = create_fhe_apply_lut_op(context, eint_value, lut, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, apply_lut_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(apply_lut_op);
|
||||
let expected_op = "%1 = \"FHE.apply_lookup_table\"(%0, %cst) : (!FHE.eint<6>, tensor<64xi64>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -203,6 +203,71 @@ pub fn create_ret_op(context: MlirContext, ret_value: MlirValue) -> MlirOperatio
|
||||
create_op(context, "func.return", &[ret_value], &[], &[], false)
|
||||
}
|
||||
|
||||
pub fn create_constant_int_op(context: MlirContext, cst_value: i64, width: u32) -> MlirOperation {
|
||||
unsafe {
|
||||
let result_type = mlirIntegerTypeGet(context, width);
|
||||
let value_str = CString::new("value").unwrap();
|
||||
let value_attr = mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
|
||||
mlirIntegerAttrGet(result_type, cst_value),
|
||||
);
|
||||
create_op(
|
||||
context,
|
||||
"arith.constant",
|
||||
&[],
|
||||
&[result_type],
|
||||
&[value_attr],
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_constant_flat_tensor_op(
|
||||
context: MlirContext,
|
||||
cst_table: &[i64],
|
||||
bitwidth: u32,
|
||||
) -> MlirOperation {
|
||||
let shape = [cst_table.len().try_into().unwrap()];
|
||||
create_constant_tensor_op(context, &shape, cst_table, bitwidth)
|
||||
}
|
||||
|
||||
pub fn create_constant_tensor_op(
|
||||
context: MlirContext,
|
||||
shape: &[i64],
|
||||
cst_table: &[i64],
|
||||
bitwidth: u32,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let result_type = mlirRankedTensorTypeGet(
|
||||
shape.len().try_into().unwrap(),
|
||||
shape.as_ptr(),
|
||||
mlirIntegerTypeGet(context, bitwidth),
|
||||
mlirAttributeGetNull(),
|
||||
);
|
||||
let cst_table_attrs: Vec<MlirAttribute> = cst_table
|
||||
.into_iter()
|
||||
.map(|value| mlirIntegerAttrGet(mlirIntegerTypeGet(context, bitwidth), *value))
|
||||
.collect();
|
||||
let value_str = CString::new("value").unwrap();
|
||||
let value_attr = mlirNamedAttributeGet(
|
||||
mlirIdentifierGet(context, mlirStringRefCreateFromCString(value_str.as_ptr())),
|
||||
mlirDenseElementsAttrGet(
|
||||
result_type,
|
||||
cst_table.len().try_into().unwrap(),
|
||||
cst_table_attrs.as_ptr(),
|
||||
),
|
||||
);
|
||||
create_op(
|
||||
context,
|
||||
"arith.constant",
|
||||
&[],
|
||||
&[result_type],
|
||||
&[value_attr],
|
||||
true,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -301,4 +366,64 @@ module {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_flat_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_flat_tensor_op = create_constant_flat_tensor_op(context, &[0, 1, 2, 3], 64);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_flat_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi64>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_tensor() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0, 1, 2, 3], 64);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi64>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_tensor_with_signle_elem() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// create a constant tensor
|
||||
let contant_tensor_op = create_constant_tensor_op(context, &[2, 2], &[0], 7);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_tensor_op);
|
||||
let expected_op = "%cst = arith.constant dense<0> : tensor<2x2xi7>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_constant_int() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// create a constant flat tensor
|
||||
let contant_int_op = create_constant_int_op(context, 73, 10);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(contant_int_op);
|
||||
let expected_op = "%c73_i10 = arith.constant 73 : i10\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user