From 41c9f868031f19c24e4ec1ca7d12830a3f4e4a03 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 26 Aug 2022 14:17:58 +0300 Subject: [PATCH] feat: create encrypted signed integer type --- .../Dialect/FHE/IR/CMakeLists.txt | 4 + .../Dialect/FHE/IR/FHEInterfaces.td | 32 ++++ .../concretelang/Dialect/FHE/IR/FHEOps.h | 4 +- .../concretelang/Dialect/FHE/IR/FHEOps.td | 146 +++++++++----- .../concretelang/Dialect/FHE/IR/FHETypes.h | 2 + .../concretelang/Dialect/FHE/IR/FHETypes.td | 41 +++- .../Dialect/FHELinalg/IR/FHELinalgOps.td | 164 +++++++++++----- .../TensorOpsToLinalg.cpp | 180 ++++++++++++++++++ compiler/lib/Dialect/FHE/IR/FHEDialect.cpp | 34 +++- compiler/lib/Dialect/FHE/IR/FHEOps.cpp | 154 ++++++++++----- .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 155 +++++++++++---- .../FHELinalgToLinalg/to_signed.mlir | 18 ++ .../FHELinalgToLinalg/to_unsigned.mlir | 18 ++ .../Dialect/FHE/add_eint.invalid.mlir | 31 +++ .../Dialect/FHE/add_eint_int.invalid.mlir | 26 +++ .../Dialect/FHE/eint_error_p_too_small.mlir | 11 +- .../Dialect/FHE/mul_eint_int.invalid.mlir | 26 +++ .../Dialect/FHE/neg_eint.invalid.mlir | 15 ++ .../Dialect/FHE/op_add_eint_err_inputs.mlir | 7 - .../Dialect/FHE/op_add_eint_err_result.mlir | 7 - .../FHE/op_add_eint_int_err_inputs.mlir | 8 - .../FHE/op_add_eint_int_err_result.mlir | 8 - .../FHE/op_mul_eint_int_err_inputs.mlir | 8 - .../FHE/op_mul_eint_int_err_result.mlir | 8 - .../Dialect/FHE/op_neg_eint_err_result.mlir | 7 - .../FHE/op_sub_int_eint_err_inputs.mlir | 8 - .../FHE/op_sub_int_eint_err_result.mlir | 8 - .../tests/check_tests/Dialect/FHE/ops.mlir | 115 ++++++++++- .../Dialect/FHE/sub_eint.invalid.mlir | 31 +++ .../Dialect/FHE/sub_int_eint.invalid.mlir | 26 +++ .../Dialect/FHE/to_signed.invalid.mlir | 7 + .../Dialect/FHE/to_unsigned.invalid.mlir | 7 + .../Dialect/FHELinalg/concat.invalid.mlir | 6 +- .../Dialect/FHELinalg/dot.invalid.mlir | 4 +- .../Dialect/FHELinalg/sum.invalid.mlir | 2 +- .../Dialect/FHELinalg/to_signed.invalid.mlir | 15 ++ .../Dialect/FHELinalg/to_signed.mlir | 23 +++ .../FHELinalg/to_unsigned.invalid.mlir | 15 ++ .../Dialect/FHELinalg/to_unsigned.mlir | 23 +++ 39 files changed, 1144 insertions(+), 260 deletions(-) create mode 100644 compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td create mode 100644 compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_signed.mlir create mode 100644 compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_unsigned.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/add_eint.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/add_eint_int.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/mul_eint_int.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/neg_eint.invalid.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_inputs.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_result.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_inputs.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_result.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_inputs.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_result.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_neg_eint_err_result.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_inputs.mlir delete mode 100644 compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_result.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/sub_eint.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/sub_int_eint.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/to_signed.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHE/to_unsigned.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHELinalg/to_signed.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHELinalg/to_signed.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.invalid.mlir create mode 100644 compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.mlir diff --git a/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt index 9760b5f5d..54e5a4799 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/IR/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td) +mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs) + set(LLVM_TARGET_DEFINITIONS FHEOps.td) mlir_tablegen(FHEOps.h.inc -gen-op-decls) mlir_tablegen(FHEOps.cpp.inc -gen-op-defs) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td new file mode 100644 index 000000000..cb0afe564 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEInterfaces.td @@ -0,0 +1,32 @@ +#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES +#define CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES + +include "mlir/IR/OpBase.td" + +def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> { + let cppNamespace = "mlir::concretelang::FHE"; + + let description = [{ + Interface for encapsulating the common properties of encrypted integer types. + }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Get bit-width of the integer.", + /*retTy=*/"unsigned", + /*methodName=*/"getWidth" + >, + InterfaceMethod< + /*description=*/"Get whether the integer is signed.", + /*retTy=*/"bool", + /*methodName=*/"isSigned" + >, + InterfaceMethod< + /*description=*/"Get whether the integer is unsigned.", + /*retTy=*/"bool", + /*methodName=*/"isUnsigned" + > + ]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h index ac91f13cc..e58216dc4 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h @@ -18,10 +18,10 @@ namespace concretelang { namespace FHE { bool verifyEncryptedIntegerInputAndResultConsistency( - Operation &op, EncryptedIntegerType &input, EncryptedIntegerType &result); + Operation &op, FheIntegerInterface &input, FheIntegerInterface &result); bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op, - EncryptedIntegerType &a, + FheIntegerInterface &a, IntegerType &b); /// Shared error message for all ApplyLookupTable variant Op (several Dialect) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 87de2edab..ca6930f31 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -27,14 +27,14 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> { Example: ```mlir "FHE.zero"() : () -> !FHE.eint<2> + "FHE.zero"() : () -> !FHE.esint<2> ``` }]; let arguments = (ins); - let results = (outs FHE_EncryptedIntegerType:$out); + let results = (outs FHE_AnyEncryptedInteger:$out); } - def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> { let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; @@ -44,36 +44,38 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> { Example: ```mlir %tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<4>> + %tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.esint<4>> ``` }]; let arguments = (ins); - let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); + let results = (outs Type.predicate, HasStaticShapePred]>>:$tensor); } def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> { - let summary = "Adds an encrypted integer and a clear integer"; let description = [{ Adds an encrypted integer and a clear integer. The clear integer must have at most one more bit than the encrypted integer - and the result must have the same width than the encrypted integer. + and the result must have the same width and the same signedness as the encrypted integer. Example: ```mlir // ok "FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + "FHE.add_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2> // error "FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2> "FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3> + "FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2> ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -86,26 +88,28 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> { } def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> { - let summary = "Adds two encrypted integers"; - let description = [{ + let description = [{ Adds two encrypted integers - The encrypted integers and the result must have the same width. + The encrypted integers and the result must have the same width and the same signedness. Example: ```mlir // ok "FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + "FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>) // error "FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) "FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + "FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>) + "FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, FHE_EncryptedIntegerType:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -117,27 +121,28 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> { } def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> { - - let summary = "Substract a clear integer and an encrypted integer"; + let summary = "Subtract an encrypted integer from a clear integer"; let description = [{ - Substract a clear integer and an encrypted integer. - The clear integer must have at most one more bit than the encrypted integer - and the result must have the same width than the encrypted integer. + Subtract an encrypted integer from a clear integer. + The clear integer must have one more bit than the encrypted integer + and the result must have the same width and the same signedness as the encrypted integer. Example: ```mlir // ok "FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<2> + "FHE.sub_int_eint"(%i, %a) : (i3, !FHE.esint<2>) -> !FHE.esint<2> // error "FHE.sub_int_eint"(%i, %a) : (i4, !FHE.eint<2>) -> !FHE.eint<2> "FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<3> + "FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.esint<2> ``` }]; - let arguments = (ins AnyInteger:$a, FHE_EncryptedIntegerType:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins AnyInteger:$a, FHE_AnyEncryptedInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -149,27 +154,28 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> { } def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> { - - let summary = "Substract a clear integer from an encrypted integer"; + let summary = "Subtract a clear integer from an encrypted integer"; let description = [{ - Substract a clear integer from an encrypted integer. - The clear integer must have at most one more bit than the encrypted integer - and the result must have the same width than the encrypted integer. + Subtract a clear integer from an encrypted integer. + The clear integer must have one more bit than the encrypted integer + and the result must have the same width and the same signedness as the encrypted integer. Example: ```mlir // ok "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + "FHE.sub_eint_int"(%i, %a) : (!FHE.esint<2>, i3) -> !FHE.esint<2> // error "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2> "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3> + "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.esint<2> ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -178,31 +184,32 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> { ]; let hasVerifier = 1; - let hasFolder = 1; } def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> { - - let summary = "Subtracts two encrypted integers"; + let summary = "Subtract an encrypted integer from an encrypted integer"; let description = [{ - Subtracts two encrypted integers - The encrypted integers and the result must have the same width. + Subtract an encrypted integer from an encrypted integer. + The encrypted integers and the result must have the same width and the same signedness. Example: ```mlir // ok "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + "FHE.sub_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>) // error "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.esint<2>) + "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>) ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, FHE_EncryptedIntegerType:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -219,20 +226,22 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> { let description = [{ Negates an encrypted integer. - The result must have the same width than the encrypted integer. + The result must have the same width and the same signedness as the encrypted integer. Example: ```mlir // ok "FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<2>) + "FHE.neg_eint"(%a): (!FHE.esint<2>) -> (!FHE.esint<2>) // error "FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<3>) + "FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.esint<2>) ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a), [{ @@ -243,27 +252,28 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> { } def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> { - - let summary = "Mulitplies an encrypted integer and a clear integer"; + let summary = "Multiply an encrypted integer with a clear integer"; let description = [{ - Mulitplies an encrypted integer and a clear integer. - The clear integer must have at most one more bit than the encrypted integer - and the result must have the same width than the encrypted integer. + Multiply an encrypted integer with a clear integer. + The clear integer must have one more bit than the encrypted integer + and the result must have the same width and the same signedness as the encrypted integer. Example: ```mlir // ok "FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + "FHE.mul_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2> // error "FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2> "FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3> + "FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2> ``` }]; - let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b); - let results = (outs FHE_EncryptedIntegerType); + let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b); + let results = (outs FHE_AnyEncryptedInteger); let builders = [ OpBuilder<(ins "Value":$a, "Value":$b), [{ @@ -275,6 +285,56 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> { let hasFolder = 1; } +def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> { + let summary = "Cast an unsigned integer to a signed one"; + + let description = [{ + Cast an unsigned integer to a signed one. + The result must have the same width as the input. + + The behavior is undefined on overflow/underflow. + + Examples: + ```mlir + // ok + "FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<2> + + // error + "FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<3> + ``` + }]; + + let arguments = (ins FHE_EncryptedIntegerType:$input); + let results = (outs FHE_EncryptedSignedIntegerType); + + let hasVerifier = 1; +} + +def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> { + let summary = "Cast a signed integer to an unsigned one"; + + let description = [{ + Cast a signed integer to an unsigned one. + The result must have the same width as the input. + + The behavior is undefined on overflow/underflow. + + Examples: + ```mlir + // ok + "FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<2> + + // error + "FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<3> + ``` + }]; + + let arguments = (ins FHE_EncryptedSignedIntegerType:$input); + let results = (outs FHE_EncryptedIntegerType); + + let hasVerifier = 1; +} + def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> { let summary = "Applies a clear lookup table to an encrypted integer"; diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h index 0d7a3687e..9db4a81c8 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.h @@ -11,6 +11,8 @@ #include #include +#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.h.inc" + #define GET_TYPEDEF_CLASSES #include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc" diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td index 2fa4d6993..815d72ed4 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHETypes.td @@ -2,13 +2,14 @@ #define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES include "concretelang/Dialect/FHE/IR/FHEDialect.td" +include "concretelang/Dialect/FHE/IR/FHEInterfaces.td" include "mlir/IR/BuiltinTypes.td" class FHE_Type traits = []> : TypeDef { } def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger", - [MemRefElementTypeInterface]> { + [MemRefElementTypeInterface, FheIntegerInterface]> { let mnemonic = "eint"; let summary = "An encrypted integer"; @@ -28,6 +29,44 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger", let hasCustomAssemblyFormat = 1; let genVerifyDecl = true; + + let extraClassDeclaration = [{ + bool isSigned() const { return false; } + bool isUnsigned() const { return true; } + }]; } +def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger", + [MemRefElementTypeInterface, FheIntegerInterface]> { + let mnemonic = "esint"; + + let summary = "An encrypted signed integer"; + + let description = [{ + An encrypted signed integer with `width` bits to performs FHE Operations. + + Examples: + ```mlir + !FHE.esint<7> + !FHE.esint<6> + ``` + }]; + + let parameters = (ins "unsigned":$width); + + let hasCustomAssemblyFormat = 1; + + let genVerifyDecl = true; + + let extraClassDeclaration = [{ + bool isSigned() const { return true; } + bool isUnsigned() const { return false; } + }]; +} + +def FHE_AnyEncryptedInteger : Type>; + #endif diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index d1a5a1983..5b9439155 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -22,8 +22,8 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRul let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers."; let description = [{ - Performs an addition follwing the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers. - The width of the clear integers must be less than or equals to the witdh of encrypted integers. + Performs an addition following the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers. + The width of the clear integers must be less than or equals to the width of encrypted integers. Examples: ```mlir @@ -58,11 +58,11 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRul }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$lhs, Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ @@ -77,7 +77,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers."; let description = [{ - Performs an addition follwing the broadcasting rules between two tensors of encrypted integers. + Performs an addition following the broadcasting rules between two tensors of encrypted integers. The width of the encrypted integers must be equals. Examples: @@ -112,11 +112,11 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, - Type.predicate, HasStaticShapePred]>>:$rhs + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ @@ -126,21 +126,21 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten } def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> { - let summary = "Returns a tensor that contains the substraction of a tensor of clear integers and a tensor of encrypted integers."; + let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers."; let description = [{ - Performs a substraction following the broadcasting rules between a tensor of clear integers and a tensor of encrypted integers. - The width of the clear integers must be less than or equals to the witdh of encrypted integers. + Performs a subtraction following the broadcasting rules between a tensor of clear integers and a tensor of encrypted integers. + The width of the clear integers must be less than or equals to the width of encrypted integers. Examples: ```mlir - // Returns the term to term substraction of `%a0` with `%a1` + // Returns the term to term subtraction of `%a0` with `%a1` "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> - // Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched. + // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched. "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. // // [1,2,3] [1] [0,2,3] // [4,5,6] - [2] = [2,3,4] @@ -149,7 +149,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul // The dimension #1 of operand #2 is stretched as it is equals to 1. "FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. // // [1,2,3] [0,0,0] // [4,5,6] - [1,2,3] = [3,3,3] @@ -166,10 +166,10 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul let arguments = (ins Type.predicate, HasStaticShapePred]>>:$lhs, - Type.predicate, HasStaticShapePred]>>:$rhs + Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{ @@ -179,21 +179,21 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul } def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { - let summary = "Returns a tensor that contains the substraction of a tensor of clear integers from a tensor of encrypted integers."; + let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers."; let description = [{ - Performs a substraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers. - The width of the clear integers must be less than or equals to the witdh of encrypted integers. + Performs a subtraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers. + The width of the clear integers must be less than or equals to the width of encrypted integers. Examples: ```mlir - // Returns the term to term substraction of `%a0` with `%a1` + // Returns the term to term subtraction of `%a0` with `%a1` "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>> - // Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched. + // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched. "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x4x4x!FHE.eint<4>>, tensor<4x1x4xi5>) -> tensor<4x4x4x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. // // [1,2,3] [1] [0,2,3] // [4,5,6] - [2] = [2,3,4] @@ -202,7 +202,7 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRul // The dimension #1 of operand #2 is stretched as it is equals to 1. "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x1x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. // // [1,2,3] [0,0,0] // [4,5,6] - [1,2,3] = [3,3,3] @@ -218,11 +218,11 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRul }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$lhs, Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ @@ -238,7 +238,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers."; let description = [{ - Performs an subtraction follwing the broadcasting rules between two tensors of encrypted integers. + Performs an subtraction following the broadcasting rules between two tensors of encrypted integers. The width of the encrypted integers must be equal. Examples: @@ -249,7 +249,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched. "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. // // [1,2,3] [1] [0,2,3] // [4,5,6] - [2] = [2,3,4] @@ -258,7 +258,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten // The dimension #1 of operand #2 is stretched as it is equals to 1. "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> - // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. // // [1,2,3] [0,0,0] // [4,5,6] - [1,2,3] = [3,3,3] @@ -273,11 +273,11 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, - Type.predicate, HasStaticShapePred]>>:$rhs + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ @@ -306,10 +306,10 @@ def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> { }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$tensor + Type.predicate, HasStaticShapePred]>>:$tensor ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let builders = [ OpBuilder<(ins "Value":$tensor), [{ @@ -323,7 +323,7 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul let description = [{ Performs a multiplication following the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers. - The width of the clear integers must be less than or equals to the witdh of encrypted integers. + The width of the clear integers must be less than or equals to the width of encrypted integers. Examples: ```mlir @@ -358,11 +358,11 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$lhs, Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasFolder = 1; } @@ -394,11 +394,11 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> { }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$t, + Type.predicate, HasStaticShapePred]>>:$t, Type.predicate, HasStaticShapePred]>>:$lut ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; } @@ -519,10 +519,10 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> { }]; let arguments = (ins - Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs, + Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs, Type.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs); - let results = (outs FHE_EncryptedIntegerType:$out); + let results = (outs FHE_AnyEncryptedInteger:$out); let hasVerifier = 1; } @@ -656,11 +656,11 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEin }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$lhs, Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; } @@ -795,10 +795,10 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt let arguments = (ins Type.predicate, HasStaticShapePred]>>:$lhs, - Type.predicate, HasStaticShapePred]>>:$rhs + Type.predicate, HasStaticShapePred]>>:$rhs ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; } @@ -871,15 +871,15 @@ def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> { }]; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$tensor, + Type.predicate, HasStaticShapePred]>>:$tensor, DefaultValuedAttr:$axes, DefaultValuedAttr:$keep_dims ); let results = (outs TypeConstraint.predicate, HasStaticShapePred]> + FHE_AnyEncryptedInteger.predicate, + And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]> ]>>:$out ); @@ -917,12 +917,12 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> { }]; let arguments = (ins - Variadic.predicate, HasStaticShapePred]>>>:$ins, + Variadic.predicate, HasStaticShapePred]>>>:$ins, DefaultValuedAttr:$axis ); let results = (outs - Type.predicate, HasStaticShapePred]>>:$out + Type.predicate, HasStaticShapePred]>>:$out ); let hasVerifier = 1; @@ -931,7 +931,7 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> { def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW"; let arguments = (ins - Type.predicate, HasStaticShapePred]>>:$input, + Type.predicate, HasStaticShapePred]>>:$input, Type.predicate, HasStaticShapePred]>>:$weight, Optional.predicate, HasStaticShapePred]>>>:$bias, // Since there is no U64ElementsAttr, we use I64 and make sure there is no neg values during verification @@ -940,7 +940,7 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { OptionalAttr:$dilations, OptionalAttr:$group ); - let results = (outs Type.predicate, HasStaticShapePred]>>); + let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; } @@ -989,4 +989,64 @@ def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", []> { let hasVerifier = 1; } +def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", []> { + let summary = "Cast an unsigned integer tensor to a signed one"; + + let description = [{ + Cast an unsigned integer tensor to a signed one. + The result must have the same width and the same shape as the input. + + The behavior is undefined on overflow/underflow. + + Examples: + ```mlir + // ok + "FHELinalg.to_signed"(%x) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> + + // error + "FHELinalg.to_signed"(%x) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$input + ); + + let results = (outs + Type.predicate, HasStaticShapePred]>>:$output + ); + + let hasVerifier = 1; +} + +def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", []> { + let summary = "Cast a signed integer tensor to an unsigned one"; + + let description = [{ + Cast a signed integer tensor to an unsigned one. + The result must have the same width and the same shape as the input. + + The behavior is undefined on overflow/underflow. + + Examples: + ```mlir + // ok + "FHELinalg.to_unsigned"(%x) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> + + // error + "FHELinalg.to_unsigned"(%x) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$input + ); + + let results = (outs + Type.predicate, HasStaticShapePred]>>:$output + ); + + let hasVerifier = 1; +} + #endif diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 4440837de..8975db510 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1775,6 +1775,184 @@ struct FHELinalgConv2dToLinalgConv2d }; }; +/// This template rewrite pattern transforms any instance of +/// operators `FHELinalg.to_signed` to an instance of `linalg.generic` with an +/// appropriate region using `FHE.to_signed` operation, an appropriate +/// specification for the iteration dimensions and appropriate operations +/// managing the accumulator of `linalg.generic`. +/// +/// Example: +/// +/// FHELinalg.to_signed(%tensor): +/// tensor> -> tensor> +/// +/// becomes: +/// +/// #maps = [ +/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>, +/// affine_map<(aN, ..., a1) -> (aN, ..., a1)> +/// ] +/// #attributes { +/// indexing_maps = #maps, +/// iterator_types = ["parallel", "parallel"], +/// } +/// +/// %init = linalg.init_tensor [DN,...,D1] : tensor> +/// %result = linalg.generic { +/// ins(%tensor: tensor>) +/// outs(%init: tensor>) +/// { +/// ^bb0(%arg0: !FHE.eint

): +/// %0 = FHE.to_signed(%arg0): !FHE.eint

-> !FHE.esint

+/// linalg.yield %0 : !FHE.esint

+/// } +/// } +/// +struct FHELinalgToSignedToLinalgGeneric + : public mlir::OpRewritePattern { + FHELinalgToSignedToLinalgGeneric( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHELinalg::ToSignedOp op, + mlir::PatternRewriter &rewriter) const override { + + mlir::RankedTensorType inputTy = + op.input().getType().cast(); + mlir::RankedTensorType resultTy = + op->getResult(0).getType().cast(); + + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + llvm::SmallVector maps{ + mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(), + this->getContext()), + mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), + this->getContext()), + }; + + llvm::SmallVector iteratorTypes(resultTy.getShape().size(), + "parallel"); + + auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + auto fheOp = nestedBuilder.create( + op.getLoc(), resultTy.getElementType(), blockArgs[0]); + + nestedBuilder.create(op.getLoc(), + fheOp.getResult()); + }; + + llvm::SmallVector resTypes{init.getType()}; + llvm::SmallVector ins{op.input()}; + llvm::SmallVector outs{init}; + + llvm::StringRef doc{""}; + llvm::StringRef call{""}; + + auto genericOp = rewriter.create( + op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, + bodyBuilder); + + rewriter.replaceOp(op, {genericOp.getResult(0)}); + return mlir::success(); + }; +}; + +/// This template rewrite pattern transforms any instance of +/// operators `FHELinalg.to_unsigned` to an instance of `linalg.generic` with an +/// appropriate region using `FHE.to_unsigned` operation, an appropriate +/// specification for the iteration dimensions and appropriate operations +/// managing the accumulator of `linalg.generic`. +/// +/// Example: +/// +/// FHELinalg.to_unsigned(%tensor): +/// tensor> -> tensor> +/// +/// becomes: +/// +/// #maps = [ +/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>, +/// affine_map<(aN, ..., a1) -> (aN, ..., a1)> +/// ] +/// #attributes { +/// indexing_maps = #maps, +/// iterator_types = ["parallel", "parallel"], +/// } +/// +/// %init = linalg.init_tensor [DN,...,D1] : tensor> +/// %result = linalg.generic { +/// ins(%tensor: tensor>) +/// outs(%init: tensor>) +/// { +/// ^bb0(%arg0: !FHE.esint

): +/// %0 = FHE.to_unsigned(%arg0): !FHE.esint

-> !FHE.eint

+/// linalg.yield %0 : !FHE.eint

+/// } +/// } +/// +struct FHELinalgToUnsignedToLinalgGeneric + : public mlir::OpRewritePattern { + FHELinalgToUnsignedToLinalgGeneric( + mlir::MLIRContext *context, + mlir::PatternBenefit benefit = + mlir::concretelang::DEFAULT_PATTERN_BENEFIT) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(FHELinalg::ToUnsignedOp op, + mlir::PatternRewriter &rewriter) const override { + + mlir::RankedTensorType inputTy = + op.input().getType().cast(); + mlir::RankedTensorType resultTy = + op->getResult(0).getType().cast(); + + mlir::Value init = rewriter.create( + op.getLoc(), resultTy, mlir::ValueRange{}); + + llvm::SmallVector maps{ + mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(), + this->getContext()), + mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(), + this->getContext()), + }; + + llvm::SmallVector iteratorTypes(resultTy.getShape().size(), + "parallel"); + + auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + auto fheOp = nestedBuilder.create( + op.getLoc(), resultTy.getElementType(), blockArgs[0]); + + nestedBuilder.create(op.getLoc(), + fheOp.getResult()); + }; + + llvm::SmallVector resTypes{init.getType()}; + llvm::SmallVector ins{op.input()}; + llvm::SmallVector outs{init}; + + llvm::StringRef doc{""}; + llvm::StringRef call{""}; + + auto genericOp = rewriter.create( + op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call, + bodyBuilder); + + rewriter.replaceOp(op, {genericOp.getResult(0)}); + return mlir::success(); + }; +}; + namespace { struct FHETensorOpsToLinalg : public FHETensorOpsToLinalgBase { @@ -1847,6 +2025,8 @@ void FHETensorOpsToLinalg::runOnOperation() { patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); + patterns.insert(&getContext()); + patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) .failed()) diff --git a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp index 4cf84fd5f..3bea0d3a9 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEDialect.cpp @@ -7,6 +7,8 @@ #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" +#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.cpp.inc" + #define GET_TYPEDEF_CLASSES #include "concretelang/Dialect/FHE/IR/FHEOpsTypes.cpp.inc" @@ -31,7 +33,7 @@ void FHEDialect::initialize() { mlir::LogicalResult EncryptedIntegerType::verify( llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { if (p == 0) { - emitError() << "FHE.eint didn't support precision equals to 0"; + emitError() << "FHE.eint doesn't support precision of 0"; return mlir::failure(); } return mlir::success(); @@ -57,3 +59,33 @@ mlir::Type EncryptedIntegerType::parse(mlir::AsmParser &p) { return getChecked(loc, loc.getContext(), width); } + +mlir::LogicalResult EncryptedSignedIntegerType::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) { + if (p == 0) { + emitError() << "FHE.esint doesn't support precision of 0"; + return mlir::failure(); + } + return mlir::success(); +} + +void EncryptedSignedIntegerType::print(mlir::AsmPrinter &p) const { + p << "<" << getWidth() << ">"; +} + +mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) { + if (p.parseLess()) + return mlir::Type(); + + int width; + + if (p.parseInteger(width)) + return mlir::Type(); + + if (p.parseGreater()) + return mlir::Type(); + + mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc()); + + return getChecked(loc, loc.getContext(), width); +} diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index afd69ae04..dd26a83d0 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -14,113 +14,144 @@ namespace concretelang { namespace FHE { bool verifyEncryptedIntegerInputAndResultConsistency( - ::mlir::Operation &op, EncryptedIntegerType &input, - EncryptedIntegerType &result) { - if (input.getWidth() != result.getWidth()) { + mlir::Operation &op, FheIntegerInterface &input, + FheIntegerInterface &result) { + + if (input.isSigned() != result.isSigned()) { op.emitOpError( - " should have the width of encrypted inputs and result equals"); + "should have the signedness of encrypted inputs and result equal"); return false; } + + if (input.getWidth() != result.getWidth()) { + op.emitOpError( + "should have the width of encrypted inputs and result equal"); + return false; + } + return true; } -bool verifyEncryptedIntegerAndIntegerInputsConsistency(::mlir::Operation &op, - EncryptedIntegerType &a, +bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op, + FheIntegerInterface &a, IntegerType &b) { + if (a.getWidth() + 1 != b.getWidth()) { - op.emitOpError(" should have the width of plain input equals to width of " + op.emitOpError("should have the width of plain input equal to width of " "encrypted input + 1"); return false; } + return true; } -bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op, - EncryptedIntegerType &a, - EncryptedIntegerType &b) { - if (a.getWidth() != b.getWidth()) { - op.emitOpError(" should have the width of encrypted inputs equals"); +bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op, + FheIntegerInterface &a, + FheIntegerInterface &b) { + if (a.isSigned() != b.isSigned()) { + op.emitOpError("should have the signedness of encrypted inputs equal"); return false; } + + if (a.getWidth() != b.getWidth()) { + op.emitOpError("should have the width of encrypted inputs equal"); + return false; + } + return true; } -::mlir::LogicalResult AddEintIntOp::verify() { - auto a = this->a().getType().cast(); +mlir::LogicalResult AddEintIntOp::verify() { + auto a = this->a().getType().dyn_cast(); auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { - return ::mlir::failure(); + return mlir::failure(); } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(), a, b)) { - return ::mlir::failure(); + return mlir::failure(); } - return ::mlir::success(); + + return mlir::success(); } -::mlir::LogicalResult AddEintOp::verify() { - auto a = this->a().getType().cast(); - auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); +mlir::LogicalResult AddEintOp::verify() { + auto a = this->a().getType().dyn_cast(); + auto b = this->b().getType().dyn_cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { return ::mlir::failure(); } + if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) { return ::mlir::failure(); } + return ::mlir::success(); } -::mlir::LogicalResult SubIntEintOp::verify() { +mlir::LogicalResult SubIntEintOp::verify() { auto a = this->a().getType().cast(); - auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); + auto b = this->b().getType().dyn_cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), b, out)) { - return ::mlir::failure(); + return mlir::failure(); } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(), b, a)) { - return ::mlir::failure(); + return mlir::failure(); } - return ::mlir::success(); + + return mlir::success(); } -::mlir::LogicalResult SubEintIntOp::verify() { - auto a = this->a().getType().cast(); +mlir::LogicalResult SubEintIntOp::verify() { + auto a = this->a().getType().dyn_cast(); auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { - return ::mlir::failure(); + return mlir::failure(); } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(), a, b)) { - return ::mlir::failure(); + return mlir::failure(); } - return ::mlir::success(); + + return mlir::success(); } -::mlir::LogicalResult SubEintOp::verify() { - auto a = this->a().getType().cast(); - auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); +mlir::LogicalResult SubEintOp::verify() { + auto a = this->a().getType().dyn_cast(); + auto b = this->b().getType().dyn_cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { return ::mlir::failure(); } + if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) { return ::mlir::failure(); } + return ::mlir::success(); } -::mlir::LogicalResult NegEintOp::verify() { - auto a = this->a().getType().cast(); - auto out = this->getResult().getType().cast(); +mlir::LogicalResult NegEintOp::verify() { + auto a = this->a().getType().dyn_cast(); + auto out = this->getResult().getType().dyn_cast(); if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { return ::mlir::failure(); @@ -128,19 +159,48 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op, return ::mlir::success(); } -::mlir::LogicalResult MulEintIntOp::verify() { - auto a = this->a().getType().cast(); +mlir::LogicalResult MulEintIntOp::verify() { + auto a = this->a().getType().dyn_cast(); auto b = this->b().getType().cast(); - auto out = this->getResult().getType().cast(); + auto out = this->getResult().getType().dyn_cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, out)) { - return ::mlir::failure(); + return mlir::failure(); } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(), a, b)) { - return ::mlir::failure(); + return mlir::failure(); } - return ::mlir::success(); + + return mlir::success(); +} + +mlir::LogicalResult ToSignedOp::verify() { + auto input = this->input().getType().cast(); + auto output = this->getResult().getType().cast(); + + if (input.getWidth() != output.getWidth()) { + this->emitOpError( + "should have the width of encrypted input and result equal"); + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult ToUnsignedOp::verify() { + auto input = this->input().getType().cast(); + auto output = this->getResult().getType().cast(); + + if (input.getWidth() != output.getWidth()) { + this->emitOpError( + "should have the width of encrypted input and result equal"); + return mlir::failure(); + } + + return mlir::success(); } ::mlir::LogicalResult ApplyLookupTableEintOp::verify() { diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 5a8236569..b8cd724d0 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -110,32 +110,38 @@ LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } + auto el0Ty = op0Ty.getElementType() - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (el0Ty == nullptr) { - op->emitOpError() << "should have a !FHE.eint as the element type of the " - "tensor of operand #0"; + op->emitOpError() + << "should have !FHE.eint or !FHE.esint as the element type of the " + "tensor of operand #0"; return mlir::failure(); } + auto el1Ty = op1Ty.getElementType().dyn_cast_or_null(); if (el1Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #1"; return mlir::failure(); } + if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } + return mlir::success(); } @@ -144,32 +150,38 @@ LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } + auto el0Ty = op0Ty.getElementType().dyn_cast_or_null(); if (el0Ty == nullptr) { op->emitOpError() << "should have an integer as the element type of the " "tensor of operand #0"; return mlir::failure(); } + auto el1Ty = op1Ty.getElementType() - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (el1Ty == nullptr) { - op->emitOpError() << "should have a !FHE.eint as the element type of the " - "tensor of operand #1"; + op->emitOpError() + << "should have !FHE.eint or !FHE.esint as the element type of the " + "tensor of operand #1"; return mlir::failure(); } + if (el1Ty.getWidth() > el0Ty.getWidth() + 1) { op->emitOpError() << "should have the width of integer values less or equals " "than the width of encrypted values + 1"; return mlir::failure(); } + return mlir::success(); } @@ -178,34 +190,50 @@ LogicalResult verifyTensorBinaryEint(mlir::Operation *op) { op->emitOpError() << "should have exactly 2 operands"; return mlir::failure(); } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null(); if (op0Ty == nullptr || op1Ty == nullptr) { op->emitOpError() << "should have both operands as tensor"; return mlir::failure(); } + auto el0Ty = op0Ty.getElementType() - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (el0Ty == nullptr) { - op->emitOpError() << "should have a !FHE.eint as the element type of the " - "tensor of operand #0"; + op->emitOpError() + << "should have !FHE.eint or !FHE.esint as the element type of the " + "tensor of operand #0"; return mlir::failure(); } + auto el1Ty = op1Ty.getElementType() - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (el1Ty == nullptr) { - op->emitOpError() << "should have a !FHE.eint as the element type of the " - "tensor of operand #1"; + op->emitOpError() + << "should have !FHE.eint or !FHE.esint as the element type of the " + "tensor of operand #1"; return mlir::failure(); } - if (el1Ty.getWidth() != el0Ty.getWidth()) { + + if (el0Ty.isSigned() != el1Ty.isSigned()) { + op->emitOpError() + << "should have the signedness of encrypted arguments equal"; + return mlir::failure(); + } + + unsigned el0BitWidth = el0Ty.getWidth(); + unsigned el1BitWidth = el1Ty.getWidth(); + + if (el1BitWidth != el0BitWidth) { op->emitOpError() << "should have the width of encrypted equals" ", got " - << el1Ty.getWidth() << " expect " << el0Ty.getWidth(); + << el1BitWidth << " expect " << el0BitWidth; return mlir::failure(); } + return mlir::success(); } @@ -214,19 +242,23 @@ LogicalResult verifyTensorUnaryEint(mlir::Operation *op) { op->emitOpError() << "should have exactly 1 operands"; return mlir::failure(); } + auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null(); if (op0Ty == nullptr) { op->emitOpError() << "should have operand as tensor"; return mlir::failure(); } + auto el0Ty = op0Ty.getElementType() - .dyn_cast_or_null(); + .dyn_cast_or_null(); if (el0Ty == nullptr) { - op->emitOpError() << "should have a !FHE.eint as the element type of the " - "tensor operand"; + op->emitOpError() + << "should have !FHE.eint or !FHE.esint as the element type of the " + "tensor operand"; return mlir::failure(); } + return mlir::success(); } @@ -377,14 +409,14 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() { .getType() .cast() .getElementType() - .cast(); + .dyn_cast(); auto rhsEltType = this->rhs() .getType() .cast() .getElementType() .cast(); auto resultType = - this->getResult().getType().cast(); + this->getResult().getType().dyn_cast(); if (!mlir::concretelang::FHE:: verifyEncryptedIntegerAndIntegerInputsConsistency( *this->getOperation(), lhsEltType, rhsEltType)) { @@ -430,16 +462,15 @@ mlir::LogicalResult SumOp::verify() { mlir::Value output = this->getResult(); auto inputType = input.getType().dyn_cast(); - mlir::Type outputType = output.getType(); + Type outputType = output.getType(); - FHE::EncryptedIntegerType inputElementType = - inputType.getElementType().dyn_cast(); - FHE::EncryptedIntegerType outputElementType = - !outputType.isa() - ? outputType.dyn_cast() - : outputType.dyn_cast() - .getElementType() - .dyn_cast(); + auto inputElementType = + inputType.getElementType().dyn_cast(); + auto outputElementType = !outputType.isa() + ? outputType.dyn_cast() + : outputType.dyn_cast() + .getElementType() + .dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( *this->getOperation(), inputElementType, outputElementType)) { @@ -517,7 +548,7 @@ mlir::LogicalResult ConcatOp::verify() { auto outVectorType = out.getType().dyn_cast(); auto outElementType = - outVectorType.getElementType().dyn_cast(); + outVectorType.getElementType().dyn_cast(); llvm::ArrayRef outShape = outVectorType.getShape(); size_t outDims = outShape.size(); @@ -533,7 +564,7 @@ mlir::LogicalResult ConcatOp::verify() { for (mlir::Value in : this->ins()) { auto inVectorType = in.getType().dyn_cast(); auto inElementType = - inVectorType.getElementType().dyn_cast(); + inVectorType.getElementType().dyn_cast(); if (!FHE::verifyEncryptedIntegerInputAndResultConsistency( *this->getOperation(), inElementType, outElementType)) { return ::mlir::failure(); @@ -827,9 +858,11 @@ mlir::LogicalResult Conv2dOp::verify() { auto weightShape = weightTy.getShape(); auto resultShape = resultTy.getShape(); - auto p = inputTy.getElementType() - .cast() - .getWidth(); + Type inputElTy = inputTy.getElementType(); + auto p = inputElTy.isa() + ? inputElTy.cast().getWidth() + : inputElTy.cast().getWidth(); + auto weightElementTyWidth = weightTy.getElementType().cast().getWidth(); if (weightElementTyWidth != p + 1) { @@ -1068,6 +1101,62 @@ mlir::LogicalResult TransposeOp::verify() { return mlir::success(); } +mlir::LogicalResult ToSignedOp::verify() { + auto inputType = this->input().getType().cast(); + auto outputType = this->getResult().getType().cast(); + + llvm::ArrayRef inputShape = inputType.getShape(); + llvm::ArrayRef outputShape = outputType.getShape(); + + if (inputShape != outputShape) { + this->emitOpError() + << "input and output tensors should have the same shape"; + return mlir::failure(); + } + + auto inputElementType = + inputType.getElementType().cast(); + auto outputElementType = + outputType.getElementType().cast(); + + if (inputElementType.getWidth() != outputElementType.getWidth()) { + this->emitOpError() + << "input and output tensors should have the same width"; + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult ToUnsignedOp::verify() { + mlir::ShapedType inputType = + this->input().getType().dyn_cast_or_null(); + mlir::ShapedType outputType = + this->getResult().getType().dyn_cast_or_null(); + + llvm::ArrayRef inputShape = inputType.getShape(); + llvm::ArrayRef outputShape = outputType.getShape(); + + if (inputShape != outputShape) { + this->emitOpError() + << "input and output tensors should have the same shape"; + return mlir::failure(); + } + + auto inputElementType = + inputType.getElementType().cast(); + auto outputElementType = + outputType.getElementType().cast(); + + if (inputElementType.getWidth() != outputElementType.getWidth()) { + this->emitOpError() + << "input and output tensors should have the same width"; + return mlir::failure(); + } + + return mlir::success(); +} + /// Avoid addition with constant tensor of 0s OpFoldResult AddEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_signed.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_signed.mlir new file mode 100644 index 000000000..5ea13f4f7 --- /dev/null +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_signed.mlir @@ -0,0 +1,18 @@ +// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-NEXT: module { +// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> { +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>> +// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.esint<2>>) { +// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.esint<2>): +// CHECK-NEXT: %2 = "FHE.to_signed"(%arg1) : (!FHE.eint<2>) -> !FHE.esint<2> +// CHECK-NEXT: linalg.yield %2 : !FHE.esint<2> +// CHECK-NEXT: } -> tensor<2x3x4x!FHE.esint<2>> +// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.esint<2>> +// CHECK-NEXT: } +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> { + %1 = "FHELinalg.to_signed"(%arg0): (tensor<2x3x4x!FHE.eint<2>>) -> (tensor<2x3x4x!FHE.esint<2>>) + return %1: tensor<2x3x4x!FHE.esint<2>> +} diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_unsigned.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_unsigned.mlir new file mode 100644 index 000000000..8358e40fc --- /dev/null +++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/to_unsigned.mlir @@ -0,0 +1,18 @@ +// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s + +// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-NEXT: module { +// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> { +// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>> +// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.esint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) { +// CHECK-NEXT: ^bb0(%arg1: !FHE.esint<2>, %arg2: !FHE.eint<2>): +// CHECK-NEXT: %2 = "FHE.to_unsigned"(%arg1) : (!FHE.esint<2>) -> !FHE.eint<2> +// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2> +// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>> +// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>> +// CHECK-NEXT: } +// CHECK-NEXT: } +func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> { + %1 = "FHELinalg.to_unsigned"(%arg0): (tensor<2x3x4x!FHE.esint<2>>) -> (tensor<2x3x4x!FHE.eint<2>>) + return %1: tensor<2x3x4x!FHE.eint<2>> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/add_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/add_eint.invalid.mlir new file mode 100644 index 000000000..c585bc131 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/add_eint.invalid.mlir @@ -0,0 +1,31 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs equal +func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.add_eint' op should have the signedness of encrypted inputs equal +func.func @bad_inputs_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.esint<2>) -> !FHE.eint<2> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.add_eint' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.esint<2> { + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/add_eint_int.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/add_eint_int.invalid.mlir new file mode 100644 index 000000000..944aabf82 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/add_eint_int.invalid.mlir @@ -0,0 +1,26 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of plain input equal to width of encrypted input + 1 +func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + %0 = arith.constant 1 : i4 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { + %0 = arith.constant 1 : i3 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> { + %0 = arith.constant 1 : i3 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_small.mlir b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_small.mlir index 963daa63c..3571ae93d 100644 --- a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_small.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_small.mlir @@ -1,6 +1,13 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s -// CHECK-LABEL: FHE.eint didn't support precision equals to 0 +// CHECK-LABEL: FHE.eint doesn't support precision of 0 func.func @test(%arg0: !FHE.eint<0>) { return } + +// ----- + +// CHECK-LABEL: FHE.esint doesn't support precision of 0 +func.func @test_signed(%arg0: !FHE.esint<0>) { + return +} diff --git a/compiler/tests/check_tests/Dialect/FHE/mul_eint_int.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/mul_eint_int.invalid.mlir new file mode 100644 index 000000000..d70736381 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/mul_eint_int.invalid.mlir @@ -0,0 +1,26 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of plain input equal to width of encrypted input + 1 +func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + %0 = arith.constant 1 : i4 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { + %0 = arith.constant 1 : i3 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> { + %0 = arith.constant 1 : i3 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/neg_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/neg_eint.invalid.mlir new file mode 100644 index 000000000..def238be3 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/neg_eint.invalid.mlir @@ -0,0 +1,15 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.neg_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { + %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.neg_eint' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> { + %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_inputs.mlir b/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_inputs.mlir deleted file mode 100644 index 49f4b54b6..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_inputs.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs equals -func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> { - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) - return %1: !FHE.eint<2> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_result.mlir b/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_result.mlir deleted file mode 100644 index 9644f51ce..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_err_result.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs and result equals -func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> { - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) - return %1: !FHE.eint<3> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_inputs.mlir b/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_inputs.mlir deleted file mode 100644 index 884994a5d..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_inputs.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1 -func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { - %0 = arith.constant 1 : i4 - %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>) - return %1: !FHE.eint<2> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_result.mlir b/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_result.mlir deleted file mode 100644 index 7ee740557..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_add_eint_int_err_result.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of encrypted inputs and result equals -func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { - %0 = arith.constant 1 : i2 - %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i2) -> (!FHE.eint<3>) - return %1: !FHE.eint<3> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_inputs.mlir b/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_inputs.mlir deleted file mode 100644 index 0229bfcea..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_inputs.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1 -func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { - %0 = arith.constant 1 : i4 - %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>) - return %1: !FHE.eint<2> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_result.mlir b/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_result.mlir deleted file mode 100644 index fc535a225..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_mul_eint_int_err_result.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of encrypted inputs and result equals -func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { - %0 = arith.constant 1 : i2 - %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i2) -> (!FHE.eint<3>) - return %1: !FHE.eint<3> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_neg_eint_err_result.mlir b/compiler/tests/check_tests/Dialect/FHE/op_neg_eint_err_result.mlir deleted file mode 100644 index cf88222f9..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_neg_eint_err_result.mlir +++ /dev/null @@ -1,7 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.neg_eint' op should have the width of encrypted inputs and result equals -func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { - %1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.eint<3>) - return %1: !FHE.eint<3> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_inputs.mlir b/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_inputs.mlir deleted file mode 100644 index f28c50d2e..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_inputs.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1 -func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { - %0 = arith.constant 1 : i4 - %1 = "FHE.sub_int_eint"(%0, %arg0): (i4, !FHE.eint<2>) -> (!FHE.eint<2>) - return %1: !FHE.eint<2> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_result.mlir b/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_result.mlir deleted file mode 100644 index 085a220d1..000000000 --- a/compiler/tests/check_tests/Dialect/FHE/op_sub_int_eint_err_result.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s - -// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of encrypted inputs and result equals -func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { - %0 = arith.constant 1 : i2 - %1 = "FHE.sub_int_eint"(%0, %arg0): (i2, !FHE.eint<2>) -> (!FHE.eint<3>) - return %1: !FHE.eint<3> -} diff --git a/compiler/tests/check_tests/Dialect/FHE/ops.mlir b/compiler/tests/check_tests/Dialect/FHE/ops.mlir index 0807883a5..517b6d9f8 100644 --- a/compiler/tests/check_tests/Dialect/FHE/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/ops.mlir @@ -9,6 +9,15 @@ func.func @zero() -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK: func.func @zero_signed() -> !FHE.esint<2> +func.func @zero_signed() -> !FHE.esint<2> { + // CHECK-NEXT: %[[RET:.*]] = "FHE.zero"() : () -> !FHE.esint<2> + // CHECK-NEXT: return %[[RET]] : !FHE.esint<2> + + %1 = "FHE.zero"() : () -> !FHE.esint<2> + return %1: !FHE.esint<2> +} + // CHECK: func.func @zero_1D() -> tensor<4x!FHE.eint<2>> { // CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<2>> // CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<2>> @@ -18,6 +27,15 @@ func.func @zero_1D() -> tensor<4x!FHE.eint<2>> { return %0 : tensor<4x!FHE.eint<2>> } +// CHECK: func.func @zero_1D_signed() -> tensor<4x!FHE.esint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.esint<2>> +// CHECK-NEXT: } +func.func @zero_1D_signed() -> tensor<4x!FHE.esint<2>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<2>> + return %0 : tensor<4x!FHE.esint<2>> +} + // CHECK: func.func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { // CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.eint<2>> // CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.eint<2>> @@ -27,6 +45,15 @@ func.func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { return %0 : tensor<4x9x!FHE.eint<2>> } +// CHECK: func.func @zero_2D_signed() -> tensor<4x9x!FHE.esint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.esint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.esint<2>> +// CHECK-NEXT: } +func.func @zero_2D_signed() -> tensor<4x9x!FHE.esint<2>> { + %0 = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.esint<2>> + return %0 : tensor<4x9x!FHE.esint<2>> +} + // CHECK-LABEL: func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 @@ -38,6 +65,35 @@ func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @add_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> +func.func @add_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V2]] : !FHE.esint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + +// CHECK-LABEL: func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> +func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: return %[[V1]] : !FHE.eint<2> + + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// CHECK-LABEL: func.func @add_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> +func.func @add_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.esint<2>, !FHE.esint<2>) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V1]] : !FHE.esint<2> + + %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + // CHECK-LABEL: func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 @@ -49,6 +105,17 @@ func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @sub_int_eint_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> +func.func @sub_int_eint_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_int_eint"(%[[V1]], %arg0) : (i3, !FHE.esint<2>) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V2]] : !FHE.esint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.esint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + // CHECK-LABEL: func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 @@ -60,6 +127,17 @@ func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @sub_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> +func.func @sub_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V2]] : !FHE.esint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + // CHECK-LABEL: func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> @@ -69,6 +147,15 @@ func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func.func @sub_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> +func.func @sub_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.esint<2>, !FHE.esint<2>) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V1]] : !FHE.esint<2> + + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + // CHECK-LABEL: func.func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func.func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = "FHE.neg_eint"(%arg0) : (!FHE.eint<2>) -> !FHE.eint<2> @@ -89,12 +176,32 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } -// CHECK-LABEL: func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> -func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { - // CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> +// CHECK-LABEL: func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> +func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V2]] : !FHE.esint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + +// CHECK-LABEL: func.func @to_signed(%arg0: !FHE.eint<2>) -> !FHE.esint<2> +func.func @to_signed(%arg0: !FHE.eint<2>) -> !FHE.esint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.to_signed"(%arg0) : (!FHE.eint<2>) -> !FHE.esint<2> + // CHECK-NEXT: return %[[V1]] : !FHE.esint<2> + + %1 = "FHE.to_signed"(%arg0): (!FHE.eint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} + +// CHECK-LABEL: func.func @to_unsigned(%arg0: !FHE.esint<2>) -> !FHE.eint<2> +func.func @to_unsigned(%arg0: !FHE.esint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.to_unsigned"(%arg0) : (!FHE.esint<2>) -> !FHE.eint<2> // CHECK-NEXT: return %[[V1]] : !FHE.eint<2> - %1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + %1 = "FHE.to_unsigned"(%arg0): (!FHE.esint<2>) -> (!FHE.eint<2>) return %1: !FHE.eint<2> } diff --git a/compiler/tests/check_tests/Dialect/FHE/sub_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/sub_eint.invalid.mlir new file mode 100644 index 000000000..25442ce4a --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/sub_eint.invalid.mlir @@ -0,0 +1,31 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.sub_eint' op should have the width of encrypted inputs equal +func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.sub_eint' op should have the signedness of encrypted inputs equal +func.func @bad_inputs_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.esint<2>) -> !FHE.eint<2> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.sub_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.sub_eint' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.esint<2> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/sub_int_eint.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/sub_int_eint.invalid.mlir new file mode 100644 index 000000000..3b712fdc1 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/sub_int_eint.invalid.mlir @@ -0,0 +1,26 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of plain input equal to width of encrypted input + 1 +func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + %0 = arith.constant 1 : i4 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i4, !FHE.eint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of encrypted inputs and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> { + %0 = arith.constant 1 : i3 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<3>) + return %1: !FHE.eint<3> +} + +// ----- + +// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the signedness of encrypted inputs and result equal +func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> { + %0 = arith.constant 1 : i3 + %1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.esint<2>) + return %1: !FHE.esint<2> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/to_signed.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/to_signed.invalid.mlir new file mode 100644 index 000000000..afd24f3ed --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/to_signed.invalid.mlir @@ -0,0 +1,7 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.to_signed' op should have the width of encrypted input and result equal +func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.esint<3> { + %1 = "FHE.to_signed"(%arg0): (!FHE.eint<2>) -> !FHE.esint<3> + return %1: !FHE.esint<3> +} diff --git a/compiler/tests/check_tests/Dialect/FHE/to_unsigned.invalid.mlir b/compiler/tests/check_tests/Dialect/FHE/to_unsigned.invalid.mlir new file mode 100644 index 000000000..1855494b1 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHE/to_unsigned.invalid.mlir @@ -0,0 +1,7 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHE.to_unsigned' op should have the width of encrypted input and result equal +func.func @bad_result_width(%arg0: !FHE.esint<2>) -> !FHE.eint<3> { + %1 = "FHE.to_unsigned"(%arg0): (!FHE.esint<2>) -> !FHE.eint<3> + return %1: !FHE.eint<3> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir index cf71bc912..d4ab4f5f9 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/concat.invalid.mlir @@ -19,7 +19,7 @@ func.func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { // ----- func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> { - // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}} %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> return %0 : tensor<7x!FHE.eint<6>> } @@ -27,7 +27,7 @@ func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tenso // ----- func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> { - // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}} %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> return %0 : tensor<7x!FHE.eint<7>> } @@ -35,7 +35,7 @@ func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tenso // ----- func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> { - // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}} + // expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}} %0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> return %0 : tensor<7x!FHE.eint<7>> } diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/dot.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/dot.invalid.mlir index 231db48ff..8680e53c4 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/dot.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/dot.invalid.mlir @@ -47,7 +47,7 @@ func.func @dot_incompatible_return( %arg0: tensor<4x!FHE.eint<2>>, %arg1: tensor<4xi3>) -> !FHE.eint<3> { - // expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of encrypted inputs and result equals}} + // expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of encrypted inputs and result equal}} %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<3> @@ -61,7 +61,7 @@ func.func @dot_incompatible_int( %arg0: tensor<4x!FHE.eint<2>>, %arg1: tensor<4xi4>) -> !FHE.eint<2> { - // expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of plain input equals to width of encrypted input + 1}} + // expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of plain input equal to width of encrypted input + 1}} %ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) : (tensor<4x!FHE.eint<2>>, tensor<4xi4>) -> !FHE.eint<2> diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/sum.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/sum.invalid.mlir index 16b334a24..c9e746f73 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/sum.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/sum.invalid.mlir @@ -3,7 +3,7 @@ // ----- func.func @sum_invalid_bitwidth(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<6> { - // expected-error @+1 {{'FHELinalg.sum' op should have the width of encrypted inputs and result equals}} + // expected-error @+1 {{'FHELinalg.sum' op should have the width of encrypted inputs and result equal}} %1 = "FHELinalg.sum"(%arg0): (tensor<4x!FHE.eint<7>>) -> !FHE.eint<6> return %1 : !FHE.eint<6> } diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.invalid.mlir new file mode 100644 index 000000000..e07f075ba --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.invalid.mlir @@ -0,0 +1,15 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHELinalg.to_signed' op input and output tensors should have the same width +func.func @bad_result_width(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>> { + %1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>> + return %1: tensor<3x2x!FHE.esint<3>> +} + +// ----- + +// CHECK-LABEL: error: 'FHELinalg.to_signed' op input and output tensors should have the same shape +func.func @bad_result_shape(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> { + %1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> + return %1: tensor<3x!FHE.esint<2>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.mlir new file mode 100644 index 000000000..30d738fbf --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/to_signed.mlir @@ -0,0 +1,23 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_signed"(%[[a0]]) : (tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<3x!FHE.esint<2>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> { + %1 = "FHELinalg.to_signed"(%arg0): (tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> + return %1 : tensor<3x!FHE.esint<2>> +} + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_signed"(%[[a0]]) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.esint<2>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> { + %1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> + return %1 : tensor<3x2x!FHE.esint<2>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.invalid.mlir new file mode 100644 index 000000000..2d986a192 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.invalid.mlir @@ -0,0 +1,15 @@ +// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// CHECK-LABEL: error: 'FHELinalg.to_unsigned' op input and output tensors should have the same width +func.func @bad_result_width(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>> { + %1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>> + return %1: tensor<3x2x!FHE.eint<3>> +} + +// ----- + +// CHECK-LABEL: error: 'FHELinalg.to_unsigned' op input and output tensors should have the same shape +func.func @bad_result_shape(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> { + %1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> + return %1: tensor<3x!FHE.eint<2>> +} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.mlir new file mode 100644 index 000000000..cbe248fff --- /dev/null +++ b/compiler/tests/check_tests/Dialect/FHELinalg/to_unsigned.mlir @@ -0,0 +1,23 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_unsigned"(%[[a0]]) : (tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<3x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> { + %1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> + return %1 : tensor<3x!FHE.eint<2>> +} + +// ----- + +// CHECK: func.func @main(%[[a0:.*]]: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_unsigned"(%[[a0]]) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> +// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.eint<2>> +// CHECK-NEXT: } +func.func @main(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> { + %1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> + return %1 : tensor<3x2x!FHE.eint<2>> +}