mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(rust): complete FHE dialect API
helper functions to create all operations of the FHE dialect
This commit is contained in:
@@ -212,6 +212,177 @@ pub fn create_fhe_add_eint_op(
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_add_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.add_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_eint_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_eint",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_sub_int_eint_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(rhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.sub_int_eint",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_negate_eint_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(eint)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.neg_eint",
|
||||
&[eint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_mul_eint_int_op(
|
||||
context: MlirContext,
|
||||
lhs: MlirValue,
|
||||
rhs: MlirValue,
|
||||
) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(lhs)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.mul_eint_int",
|
||||
&[lhs, rhs],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_apply_lut_op(
|
||||
context: MlirContext,
|
||||
eint: MlirValue,
|
||||
lut: MlirValue,
|
||||
out_type: MlirType,
|
||||
) -> MlirOperation {
|
||||
create_op(
|
||||
context,
|
||||
"FHE.apply_lookup_table",
|
||||
&[eint, lut],
|
||||
[out_type].as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_fhe_to_signed_op(context: MlirContext, eint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(eint)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.to_signed",
|
||||
&[eint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_to_unsigned_op(context: MlirContext, esint: MlirValue) -> MlirOperation {
|
||||
unsafe {
|
||||
let results = [mlirValueGetType(esint)];
|
||||
// infer result type from operands
|
||||
create_op(
|
||||
context,
|
||||
"FHE.to_unsigned",
|
||||
&[esint],
|
||||
results.as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_fhe_zero_eint_op(context: MlirContext, out_type: MlirType) -> MlirOperation {
|
||||
create_op(context, "FHE.zero", &[], [out_type].as_slice(), &[], false)
|
||||
}
|
||||
|
||||
pub fn create_fhe_zero_eint_tensor_op(context: MlirContext, out_type: MlirType) -> MlirOperation {
|
||||
create_op(
|
||||
context,
|
||||
"FHE.zero_tensor",
|
||||
&[],
|
||||
[out_type].as_slice(),
|
||||
&[],
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn create_addi_op(context: MlirContext, lhs: MlirValue, rhs: MlirValue) -> MlirOperation {
|
||||
create_op(context, "arith.addi", &[lhs, rhs], &[], &[], true)
|
||||
}
|
||||
@@ -403,4 +574,60 @@ module {
|
||||
assert_eq!(printed_module, expected_module);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zero_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
let printed_op = print_mlir_operation_to_string(zero_op);
|
||||
let expected_op = "%0 = \"FHE.zero\"() : () -> !FHE.eint<6>\n";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_eint_op() {
|
||||
unsafe {
|
||||
let context = mlirContextCreate();
|
||||
mlirRegisterAllDialects(context);
|
||||
|
||||
// register the FHE dialect
|
||||
let fhe_handle = mlirGetDialectHandle__fhe__();
|
||||
mlirDialectHandleLoadDialect(fhe_handle, context);
|
||||
|
||||
// create a 6-bit eint type
|
||||
let eint_or_error = fheEncryptedIntegerTypeGetChecked(context, 6);
|
||||
assert!(!eint_or_error.isError);
|
||||
let eint6_type = eint_or_error.type_;
|
||||
|
||||
// create module for ops
|
||||
let location = mlirLocationUnknownGet(context);
|
||||
let module = mlirModuleCreateEmpty(location);
|
||||
let main_block = mlirModuleGetBody(module);
|
||||
// create an encrypted integer via a zero_op
|
||||
let zero_op = create_fhe_zero_eint_op(context, eint6_type);
|
||||
mlirBlockAppendOwnedOperation(main_block, zero_op);
|
||||
let eint_value = mlirOperationGetResult(zero_op, 0);
|
||||
// add eint with itself
|
||||
let add_eint_op = create_fhe_add_eint_op(context, eint_value, eint_value);
|
||||
mlirBlockAppendOwnedOperation(main_block, add_eint_op);
|
||||
|
||||
let printed_op = print_mlir_operation_to_string(add_eint_op);
|
||||
let expected_op =
|
||||
"%1 = \"FHE.add_eint\"(%0, %0) : (!FHE.eint<6>, !FHE.eint<6>) -> !FHE.eint<6>";
|
||||
assert_eq!(printed_op, expected_op);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user