feat: create encrypted signed integer type

This commit is contained in:
Umut
2022-08-26 14:17:58 +03:00
parent 39e7313348
commit 41c9f86803
39 changed files with 1144 additions and 260 deletions

View File

@@ -1,3 +1,7 @@
set(LLVM_TARGET_DEFINITIONS FHEInterfaces.td)
mlir_tablegen(FHETypesInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(FHETypesInterfaces.cpp.inc -gen-type-interface-defs)
set(LLVM_TARGET_DEFINITIONS FHEOps.td)
mlir_tablegen(FHEOps.h.inc -gen-op-decls)
mlir_tablegen(FHEOps.cpp.inc -gen-op-defs)

View File

@@ -0,0 +1,32 @@
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES
#define CONCRETELANG_DIALECT_FHE_IR_FHE_INTERFACES
include "mlir/IR/OpBase.td"
def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
let cppNamespace = "mlir::concretelang::FHE";
let description = [{
Interface for encapsulating the common properties of encrypted integer types.
}];
let methods = [
InterfaceMethod<
/*description=*/"Get bit-width of the integer.",
/*retTy=*/"unsigned",
/*methodName=*/"getWidth"
>,
InterfaceMethod<
/*description=*/"Get whether the integer is signed.",
/*retTy=*/"bool",
/*methodName=*/"isSigned"
>,
InterfaceMethod<
/*description=*/"Get whether the integer is unsigned.",
/*retTy=*/"bool",
/*methodName=*/"isUnsigned"
>
];
}
#endif

View File

@@ -18,10 +18,10 @@ namespace concretelang {
namespace FHE {
bool verifyEncryptedIntegerInputAndResultConsistency(
Operation &op, EncryptedIntegerType &input, EncryptedIntegerType &result);
Operation &op, FheIntegerInterface &input, FheIntegerInterface &result);
bool verifyEncryptedIntegerAndIntegerInputsConsistency(Operation &op,
EncryptedIntegerType &a,
FheIntegerInterface &a,
IntegerType &b);
/// Shared error message for all ApplyLookupTable variant Op (several Dialect)

View File

@@ -27,14 +27,14 @@ def FHE_ZeroEintOp : FHE_Op<"zero", [NoSideEffect]> {
Example:
```mlir
"FHE.zero"() : () -> !FHE.eint<2>
"FHE.zero"() : () -> !FHE.esint<2>
```
}];
let arguments = (ins);
let results = (outs FHE_EncryptedIntegerType:$out);
let results = (outs FHE_AnyEncryptedInteger:$out);
}
def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
let summary = "Creates a new tensor with all elements initialized to an encrypted zero.";
@@ -44,36 +44,38 @@ def FHE_ZeroTensorOp : FHE_Op<"zero_tensor", [NoSideEffect]> {
Example:
```mlir
%tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.eint<4>>
%tensor = "FHE.zero_tensor"() : () -> tensor<5x!FHE.esint<4>>
```
}];
let arguments = (ins);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor);
}
def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
let summary = "Adds an encrypted integer and a clear integer";
let description = [{
Adds an encrypted integer and a clear integer.
The clear integer must have at most one more bit than the encrypted integer
and the result must have the same width than the encrypted integer.
and the result must have the same width and the same signedness as the encrypted integer.
Example:
```mlir
// ok
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
"FHE.add_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// error
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
"FHE.add_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -86,26 +88,28 @@ def FHE_AddEintIntOp : FHE_Op<"add_eint_int", [NoSideEffect]> {
}
def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
let summary = "Adds two encrypted integers";
let description = [{
let description = [{
Adds two encrypted integers
The encrypted integers and the result must have the same width.
The encrypted integers and the result must have the same width and the same signedness.
Example:
```mlir
// ok
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
"FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
// error
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
"FHE.add_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
"FHE.add_eint"(%a, %b): (!FHE.esint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a, FHE_EncryptedIntegerType:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -117,27 +121,28 @@ def FHE_AddEintOp : FHE_Op<"add_eint", [NoSideEffect]> {
}
def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
let summary = "Substract a clear integer and an encrypted integer";
let summary = "Subtract an encrypted integer from a clear integer";
let description = [{
Substract a clear integer and an encrypted integer.
The clear integer must have at most one more bit than the encrypted integer
and the result must have the same width than the encrypted integer.
Subtract an encrypted integer from a clear integer.
The clear integer must have one more bit than the encrypted integer
and the result must have the same width and the same signedness as the encrypted integer.
Example:
```mlir
// ok
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<2>
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.esint<2>) -> !FHE.esint<2>
// error
"FHE.sub_int_eint"(%i, %a) : (i4, !FHE.eint<2>) -> !FHE.eint<2>
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.eint<3>
"FHE.sub_int_eint"(%i, %a) : (i3, !FHE.eint<2>) -> !FHE.esint<2>
```
}];
let arguments = (ins AnyInteger:$a, FHE_EncryptedIntegerType:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins AnyInteger:$a, FHE_AnyEncryptedInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -149,27 +154,28 @@ def FHE_SubIntEintOp : FHE_Op<"sub_int_eint", [NoSideEffect]> {
}
def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
let summary = "Substract a clear integer from an encrypted integer";
let summary = "Subtract a clear integer from an encrypted integer";
let description = [{
Substract a clear integer from an encrypted integer.
The clear integer must have at most one more bit than the encrypted integer
and the result must have the same width than the encrypted integer.
Subtract a clear integer from an encrypted integer.
The clear integer must have one more bit than the encrypted integer
and the result must have the same width and the same signedness as the encrypted integer.
Example:
```mlir
// ok
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
"FHE.sub_eint_int"(%i, %a) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// error
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -178,31 +184,32 @@ def FHE_SubEintIntOp : FHE_Op<"sub_eint_int", [NoSideEffect]> {
];
let hasVerifier = 1;
let hasFolder = 1;
}
def FHE_SubEintOp : FHE_Op<"sub_eint", [NoSideEffect]> {
let summary = "Subtracts two encrypted integers";
let summary = "Subtract an encrypted integer from an encrypted integer";
let description = [{
Subtracts two encrypted integers
The encrypted integers and the result must have the same width.
Subtract an encrypted integer from an encrypted integer.
The encrypted integers and the result must have the same width and the same signedness.
Example:
```mlir
// ok
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
"FHE.sub_eint"(%a, %b): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
// error
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a, FHE_EncryptedIntegerType:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a, FHE_AnyEncryptedInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -219,20 +226,22 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
let description = [{
Negates an encrypted integer.
The result must have the same width than the encrypted integer.
The result must have the same width and the same signedness as the encrypted integer.
Example:
```mlir
// ok
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<2>)
"FHE.neg_eint"(%a): (!FHE.esint<2>) -> (!FHE.esint<2>)
// error
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.eint<3>)
"FHE.neg_eint"(%a): (!FHE.eint<2>) -> (!FHE.esint<2>)
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a), [{
@@ -243,27 +252,28 @@ def FHE_NegEintOp : FHE_Op<"neg_eint", [NoSideEffect]> {
}
def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
let summary = "Mulitplies an encrypted integer and a clear integer";
let summary = "Multiply an encrypted integer with a clear integer";
let description = [{
Mulitplies an encrypted integer and a clear integer.
The clear integer must have at most one more bit than the encrypted integer
and the result must have the same width than the encrypted integer.
Multiply an encrypted integer with a clear integer.
The clear integer must have one more bit than the encrypted integer
and the result must have the same width and the same signedness as the encrypted integer.
Example:
```mlir
// ok
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
"FHE.mul_eint_int"(%a, %i) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// error
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
"FHE.mul_eint_int"(%a, %i) : (!FHE.eint<2>, i3) -> !FHE.esint<2>
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$a, AnyInteger:$b);
let results = (outs FHE_EncryptedIntegerType);
let arguments = (ins FHE_AnyEncryptedInteger:$a, AnyInteger:$b);
let results = (outs FHE_AnyEncryptedInteger);
let builders = [
OpBuilder<(ins "Value":$a, "Value":$b), [{
@@ -275,6 +285,56 @@ def FHE_MulEintIntOp : FHE_Op<"mul_eint_int", [NoSideEffect]> {
let hasFolder = 1;
}
def FHE_ToSignedOp : FHE_Op<"to_signed", [NoSideEffect]> {
let summary = "Cast an unsigned integer to a signed one";
let description = [{
Cast an unsigned integer to a signed one.
The result must have the same width as the input.
The behavior is undefined on overflow/underflow.
Examples:
```mlir
// ok
"FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<2>
// error
"FHE.to_signed"(%x) : (!FHE.eint<2>) -> !FHE.esint<3>
```
}];
let arguments = (ins FHE_EncryptedIntegerType:$input);
let results = (outs FHE_EncryptedSignedIntegerType);
let hasVerifier = 1;
}
def FHE_ToUnsignedOp : FHE_Op<"to_unsigned", [NoSideEffect]> {
let summary = "Cast a signed integer to an unsigned one";
let description = [{
Cast a signed integer to an unsigned one.
The result must have the same width as the input.
The behavior is undefined on overflow/underflow.
Examples:
```mlir
// ok
"FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<2>
// error
"FHE.to_unsigned"(%x) : (!FHE.esint<2>) -> !FHE.eint<3>
```
}];
let arguments = (ins FHE_EncryptedSignedIntegerType:$input);
let results = (outs FHE_EncryptedIntegerType);
let hasVerifier = 1;
}
def FHE_ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table", [NoSideEffect]> {
let summary = "Applies a clear lookup table to an encrypted integer";

View File

@@ -11,6 +11,8 @@
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/DialectImplementation.h>
#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.h.inc"
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.h.inc"

View File

@@ -2,13 +2,14 @@
#define CONCRETELANG_DIALECT_FHE_IR_FHE_TYPES
include "concretelang/Dialect/FHE/IR/FHEDialect.td"
include "concretelang/Dialect/FHE/IR/FHEInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
class FHE_Type<string name, list<Trait> traits = []> :
TypeDef<FHE_Dialect, name, traits> { }
def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
[MemRefElementTypeInterface]> {
[MemRefElementTypeInterface, FheIntegerInterface]> {
let mnemonic = "eint";
let summary = "An encrypted integer";
@@ -28,6 +29,44 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let extraClassDeclaration = [{
bool isSigned() const { return false; }
bool isUnsigned() const { return true; }
}];
}
def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
[MemRefElementTypeInterface, FheIntegerInterface]> {
let mnemonic = "esint";
let summary = "An encrypted signed integer";
let description = [{
An encrypted signed integer with `width` bits to performs FHE Operations.
Examples:
```mlir
!FHE.esint<7>
!FHE.esint<6>
```
}];
let parameters = (ins "unsigned":$width);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = true;
let extraClassDeclaration = [{
bool isSigned() const { return true; }
bool isUnsigned() const { return false; }
}];
}
def FHE_AnyEncryptedInteger : Type<Or<[
FHE_EncryptedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate
]>>;
#endif

View File

@@ -22,8 +22,8 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRul
let summary = "Returns a tensor that contains the addition of a tensor of encrypted integers and a tensor of clear integers.";
let description = [{
Performs an addition follwing the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
Performs an addition following the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
Examples:
```mlir
@@ -58,11 +58,11 @@ def FHELinalg_AddEintIntOp : FHELinalg_Op<"add_eint_int", [TensorBroadcastingRul
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{
@@ -77,7 +77,7 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten
let summary = "Returns a tensor that contains the addition of two tensor of encrypted integers.";
let description = [{
Performs an addition follwing the broadcasting rules between two tensors of encrypted integers.
Performs an addition following the broadcasting rules between two tensors of encrypted integers.
The width of the encrypted integers must be equals.
Examples:
@@ -112,11 +112,11 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{
@@ -126,21 +126,21 @@ def FHELinalg_AddEintOp : FHELinalg_Op<"add_eint", [TensorBroadcastingRules, Ten
}
def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, TensorBinaryIntEint]> {
let summary = "Returns a tensor that contains the substraction of a tensor of clear integers and a tensor of encrypted integers.";
let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers and a tensor of encrypted integers.";
let description = [{
Performs a substraction following the broadcasting rules between a tensor of clear integers and a tensor of encrypted integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
Performs a subtraction following the broadcasting rules between a tensor of clear integers and a tensor of encrypted integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
Examples:
```mlir
// Returns the term to term substraction of `%a0` with `%a1`
// Returns the term to term subtraction of `%a0` with `%a1`
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4xi5>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
// Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<4x1x4xi5>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
//
// [1,2,3] [1] [0,2,3]
// [4,5,6] - [2] = [2,3,4]
@@ -149,7 +149,7 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"FHELinalg.sub_int_eint"(%a0, %a1) : (tensor<3x3xi5>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
//
// [1,2,3] [0,0,0]
// [4,5,6] - [1,2,3] = [3,3,3]
@@ -166,10 +166,10 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul
let arguments = (ins
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$rhs, "Value":$lhs), [{
@@ -179,21 +179,21 @@ def FHELinalg_SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRul
}
def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
let summary = "Returns a tensor that contains the substraction of a tensor of clear integers from a tensor of encrypted integers.";
let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers.";
let description = [{
Performs a substraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
Performs a subtraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
Examples:
```mlir
// Returns the term to term substraction of `%a0` with `%a1`
// Returns the term to term subtraction of `%a0` with `%a1`
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
// Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x4x4x!FHE.eint<4>>, tensor<4x1x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
//
// [1,2,3] [1] [0,2,3]
// [4,5,6] - [2] = [2,3,4]
@@ -202,7 +202,7 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRul
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x1x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
//
// [1,2,3] [0,0,0]
// [4,5,6] - [1,2,3] = [3,3,3]
@@ -218,11 +218,11 @@ def FHELinalg_SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRul
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
@@ -238,7 +238,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten
let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers.";
let description = [{
Performs an subtraction follwing the broadcasting rules between two tensors of encrypted integers.
Performs an subtraction following the broadcasting rules between two tensors of encrypted integers.
The width of the encrypted integers must be equal.
Examples:
@@ -249,7 +249,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten
// Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
//
// [1,2,3] [1] [0,2,3]
// [4,5,6] - [2] = [2,3,4]
@@ -258,7 +258,7 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten
// The dimension #1 of operand #2 is stretched as it is equals to 1.
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
// Returns the subtraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
//
// [1,2,3] [0,0,0]
// [4,5,6] - [1,2,3] = [3,3,3]
@@ -273,11 +273,11 @@ def FHELinalg_SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, Ten
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
@@ -306,10 +306,10 @@ def FHELinalg_NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let builders = [
OpBuilder<(ins "Value":$tensor), [{
@@ -323,7 +323,7 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul
let description = [{
Performs a multiplication following the broadcasting rules between a tensor of encrypted integers and a tensor of clear integers.
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
The width of the clear integers must be less than or equals to the width of encrypted integers.
Examples:
```mlir
@@ -358,11 +358,11 @@ def FHELinalg_MulEintIntOp : FHELinalg_Op<"mul_eint_int", [TensorBroadcastingRul
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasFolder = 1;
}
@@ -394,11 +394,11 @@ def FHELinalg_ApplyLookupTableEintOp : FHELinalg_Op<"apply_lookup_table", []> {
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$t,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lut
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
@@ -519,10 +519,10 @@ def FHELinalg_Dot : FHELinalg_Op<"dot_eint_int"> {
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$lhs,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred, HasAnyRankOfPred<[1]>]>>:$rhs);
let results = (outs FHE_EncryptedIntegerType:$out);
let results = (outs FHE_AnyEncryptedInteger:$out);
let hasVerifier = 1;
}
@@ -656,11 +656,11 @@ def FHELinalg_MatMulEintIntOp : FHELinalg_Op<"matmul_eint_int", [TensorBinaryEin
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
@@ -795,10 +795,10 @@ def FHELinalg_MatMulIntEintOp : FHELinalg_Op<"matmul_int_eint", [TensorBinaryInt
let arguments = (ins
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lhs,
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$rhs
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
@@ -871,15 +871,15 @@ def FHELinalg_SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$tensor,
DefaultValuedAttr<I64ArrayAttr, "{}">:$axes,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs
TypeConstraint<Or<[
FHE_EncryptedIntegerType.predicate,
And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>
FHE_AnyEncryptedInteger.predicate,
And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>
]>>:$out
);
@@ -917,12 +917,12 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> {
}];
let arguments = (ins
Variadic<Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>>:$ins,
Variadic<Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>>:$ins,
DefaultValuedAttr<I64Attr, "0">:$axis
);
let results = (outs
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$out
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$out
);
let hasVerifier = 1;
@@ -931,7 +931,7 @@ def FHELinalg_ConcatOp : FHELinalg_Op<"concat"> {
def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
let summary = "Returns the 2D convolution of a tensor in the form NCHW with weights in the form FCHW";
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$input,
Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>:$input,
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$weight,
Optional<Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>>:$bias,
// Since there is no U64ElementsAttr, we use I64 and make sure there is no neg values during verification
@@ -940,7 +940,7 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
OptionalAttr<I64ElementsAttr>:$dilations,
OptionalAttr<I64Attr>:$group
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let results = (outs Type<And<[TensorOf<[FHE_AnyEncryptedInteger]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;
}
@@ -989,4 +989,64 @@ def FHELinalg_FromElementOp : FHELinalg_Op<"from_element", []> {
let hasVerifier = 1;
}
def FHELinalg_ToSignedOp : FHELinalg_Op<"to_signed", []> {
let summary = "Cast an unsigned integer tensor to a signed one";
let description = [{
Cast an unsigned integer tensor to a signed one.
The result must have the same width and the same shape as the input.
The behavior is undefined on overflow/underflow.
Examples:
```mlir
// ok
"FHELinalg.to_signed"(%x) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>>
// error
"FHELinalg.to_signed"(%x) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>>
```
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$input
);
let results = (outs
Type<And<[TensorOf<[FHE_EncryptedSignedIntegerType]>.predicate, HasStaticShapePred]>>:$output
);
let hasVerifier = 1;
}
def FHELinalg_ToUnsignedOp : FHELinalg_Op<"to_unsigned", []> {
let summary = "Cast a signed integer tensor to an unsigned one";
let description = [{
Cast a signed integer tensor to an unsigned one.
The result must have the same width and the same shape as the input.
The behavior is undefined on overflow/underflow.
Examples:
```mlir
// ok
"FHELinalg.to_unsigned"(%x) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>>
// error
"FHELinalg.to_unsigned"(%x) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>>
```
}];
let arguments = (ins
Type<And<[TensorOf<[FHE_EncryptedSignedIntegerType]>.predicate, HasStaticShapePred]>>:$input
);
let results = (outs
Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$output
);
let hasVerifier = 1;
}
#endif

View File

@@ -1775,6 +1775,184 @@ struct FHELinalgConv2dToLinalgConv2d
};
};
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalg.to_signed` to an instance of `linalg.generic` with an
/// appropriate region using `FHE.to_signed` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// Example:
///
/// FHELinalg.to_signed(%tensor):
/// tensor<DNx...xD1x!FHE.eint<p>> -> tensor<DNx...xD1x!FHE.esint<p>>
///
/// becomes:
///
/// #maps = [
/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>
/// ]
/// #attributes {
/// indexing_maps = #maps,
/// iterator_types = ["parallel", "parallel"],
/// }
///
/// %init = linalg.init_tensor [DN,...,D1] : tensor<DNx...xD1x!FHE.esint<p>>
/// %result = linalg.generic {
/// ins(%tensor: tensor<DNx...xD1x!FHE.eint<p>>)
/// outs(%init: tensor<DNx...xD1x!FHE.esint<p>>)
/// {
/// ^bb0(%arg0: !FHE.eint<p>):
/// %0 = FHE.to_signed(%arg0): !FHE.eint<p> -> !FHE.esint<p>
/// linalg.yield %0 : !FHE.esint<p>
/// }
/// }
///
struct FHELinalgToSignedToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalg::ToSignedOp> {
FHELinalgToSignedToLinalgGeneric(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<FHELinalg::ToSignedOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHELinalg::ToSignedOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType inputTy =
op.input().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
this->getContext()),
};
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
auto fheOp = nestedBuilder.create<FHE::ToSignedOp>(
op.getLoc(), resultTy.getElementType(), blockArgs[0]);
nestedBuilder.create<mlir::linalg::YieldOp>(op.getLoc(),
fheOp.getResult());
};
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
auto genericOp = rewriter.create<mlir::linalg::GenericOp>(
op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call,
bodyBuilder);
rewriter.replaceOp(op, {genericOp.getResult(0)});
return mlir::success();
};
};
/// This template rewrite pattern transforms any instance of
/// operators `FHELinalg.to_unsigned` to an instance of `linalg.generic` with an
/// appropriate region using `FHE.to_unsigned` operation, an appropriate
/// specification for the iteration dimensions and appropriate operations
/// managing the accumulator of `linalg.generic`.
///
/// Example:
///
/// FHELinalg.to_unsigned(%tensor):
/// tensor<DNx...xD1x!FHE.esint<p>> -> tensor<DNx...xD1x!FHE.eint<p>>
///
/// becomes:
///
/// #maps = [
/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>,
/// affine_map<(aN, ..., a1) -> (aN, ..., a1)>
/// ]
/// #attributes {
/// indexing_maps = #maps,
/// iterator_types = ["parallel", "parallel"],
/// }
///
/// %init = linalg.init_tensor [DN,...,D1] : tensor<DNx...xD1x!FHE.eint<p>>
/// %result = linalg.generic {
/// ins(%tensor: tensor<DNx...xD1x!FHE.esint<p>>)
/// outs(%init: tensor<DNx...xD1x!FHE.eint<p>>)
/// {
/// ^bb0(%arg0: !FHE.esint<p>):
/// %0 = FHE.to_unsigned(%arg0): !FHE.esint<p> -> !FHE.eint<p>
/// linalg.yield %0 : !FHE.eint<p>
/// }
/// }
///
struct FHELinalgToUnsignedToLinalgGeneric
: public mlir::OpRewritePattern<FHELinalg::ToUnsignedOp> {
FHELinalgToUnsignedToLinalgGeneric(
mlir::MLIRContext *context,
mlir::PatternBenefit benefit =
mlir::concretelang::DEFAULT_PATTERN_BENEFIT)
: mlir::OpRewritePattern<FHELinalg::ToUnsignedOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(FHELinalg::ToUnsignedOp op,
mlir::PatternRewriter &rewriter) const override {
mlir::RankedTensorType inputTy =
op.input().getType().cast<mlir::RankedTensorType>();
mlir::RankedTensorType resultTy =
op->getResult(0).getType().cast<mlir::RankedTensorType>();
mlir::Value init = rewriter.create<bufferization::AllocTensorOp>(
op.getLoc(), resultTy, mlir::ValueRange{});
llvm::SmallVector<mlir::AffineMap, 2> maps{
mlir::AffineMap::getMultiDimIdentityMap(inputTy.getShape().size(),
this->getContext()),
mlir::AffineMap::getMultiDimIdentityMap(resultTy.getShape().size(),
this->getContext()),
};
llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
"parallel");
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
auto fheOp = nestedBuilder.create<FHE::ToUnsignedOp>(
op.getLoc(), resultTy.getElementType(), blockArgs[0]);
nestedBuilder.create<mlir::linalg::YieldOp>(op.getLoc(),
fheOp.getResult());
};
llvm::SmallVector<mlir::Type, 1> resTypes{init.getType()};
llvm::SmallVector<mlir::Value, 1> ins{op.input()};
llvm::SmallVector<mlir::Value, 1> outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
auto genericOp = rewriter.create<mlir::linalg::GenericOp>(
op.getLoc(), resTypes, ins, outs, maps, iteratorTypes, doc, call,
bodyBuilder);
rewriter.replaceOp(op, {genericOp.getResult(0)});
return mlir::success();
};
};
namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1847,6 +2025,8 @@ void FHETensorOpsToLinalg::runOnOperation() {
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
patterns.insert<TransposeToLinalgGeneric>(&getContext());
patterns.insert<FromElementToTensorFromElements>(&getContext());
patterns.insert<FHELinalgToSignedToLinalgGeneric>(&getContext());
patterns.insert<FHELinalgToUnsignedToLinalgGeneric>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

View File

@@ -7,6 +7,8 @@
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/FHE/IR/FHETypesInterfaces.cpp.inc"
#define GET_TYPEDEF_CLASSES
#include "concretelang/Dialect/FHE/IR/FHEOpsTypes.cpp.inc"
@@ -31,7 +33,7 @@ void FHEDialect::initialize() {
mlir::LogicalResult EncryptedIntegerType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
if (p == 0) {
emitError() << "FHE.eint didn't support precision equals to 0";
emitError() << "FHE.eint doesn't support precision of 0";
return mlir::failure();
}
return mlir::success();
@@ -57,3 +59,33 @@ mlir::Type EncryptedIntegerType::parse(mlir::AsmParser &p) {
return getChecked(loc, loc.getContext(), width);
}
mlir::LogicalResult EncryptedSignedIntegerType::verify(
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
if (p == 0) {
emitError() << "FHE.esint doesn't support precision of 0";
return mlir::failure();
}
return mlir::success();
}
void EncryptedSignedIntegerType::print(mlir::AsmPrinter &p) const {
p << "<" << getWidth() << ">";
}
mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) {
if (p.parseLess())
return mlir::Type();
int width;
if (p.parseInteger(width))
return mlir::Type();
if (p.parseGreater())
return mlir::Type();
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
return getChecked(loc, loc.getContext(), width);
}

View File

@@ -14,113 +14,144 @@ namespace concretelang {
namespace FHE {
bool verifyEncryptedIntegerInputAndResultConsistency(
::mlir::Operation &op, EncryptedIntegerType &input,
EncryptedIntegerType &result) {
if (input.getWidth() != result.getWidth()) {
mlir::Operation &op, FheIntegerInterface &input,
FheIntegerInterface &result) {
if (input.isSigned() != result.isSigned()) {
op.emitOpError(
" should have the width of encrypted inputs and result equals");
"should have the signedness of encrypted inputs and result equal");
return false;
}
if (input.getWidth() != result.getWidth()) {
op.emitOpError(
"should have the width of encrypted inputs and result equal");
return false;
}
return true;
}
bool verifyEncryptedIntegerAndIntegerInputsConsistency(::mlir::Operation &op,
EncryptedIntegerType &a,
bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op,
FheIntegerInterface &a,
IntegerType &b) {
if (a.getWidth() + 1 != b.getWidth()) {
op.emitOpError(" should have the width of plain input equals to width of "
op.emitOpError("should have the width of plain input equal to width of "
"encrypted input + 1");
return false;
}
return true;
}
bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op,
EncryptedIntegerType &a,
EncryptedIntegerType &b) {
if (a.getWidth() != b.getWidth()) {
op.emitOpError(" should have the width of encrypted inputs equals");
bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
FheIntegerInterface &a,
FheIntegerInterface &b) {
if (a.isSigned() != b.isSigned()) {
op.emitOpError("should have the signedness of encrypted inputs equal");
return false;
}
if (a.getWidth() != b.getWidth()) {
op.emitOpError("should have the width of encrypted inputs equal");
return false;
}
return true;
}
::mlir::LogicalResult AddEintIntOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult AddEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
return mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(),
a, b)) {
return ::mlir::failure();
return mlir::failure();
}
return ::mlir::success();
return mlir::success();
}
::mlir::LogicalResult AddEintOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
auto b = this->b().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult AddEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
return ::mlir::failure();
}
return ::mlir::success();
}
::mlir::LogicalResult SubIntEintOp::verify() {
mlir::LogicalResult SubIntEintOp::verify() {
auto a = this->a().getType().cast<IntegerType>();
auto b = this->b().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), b,
out)) {
return ::mlir::failure();
return mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(),
b, a)) {
return ::mlir::failure();
return mlir::failure();
}
return ::mlir::success();
return mlir::success();
}
::mlir::LogicalResult SubEintIntOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult SubEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
return mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(),
a, b)) {
return ::mlir::failure();
return mlir::failure();
}
return ::mlir::success();
return mlir::success();
}
::mlir::LogicalResult SubEintOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
auto b = this->b().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult SubEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
}
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
return ::mlir::failure();
}
return ::mlir::success();
}
::mlir::LogicalResult NegEintOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult NegEintOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
@@ -128,19 +159,48 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op,
return ::mlir::success();
}
::mlir::LogicalResult MulEintIntOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
mlir::LogicalResult MulEintIntOp::verify() {
auto a = this->a().getType().dyn_cast<FheIntegerInterface>();
auto b = this->b().getType().cast<IntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().dyn_cast<FheIntegerInterface>();
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
out)) {
return ::mlir::failure();
return mlir::failure();
}
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(),
a, b)) {
return ::mlir::failure();
return mlir::failure();
}
return ::mlir::success();
return mlir::success();
}
mlir::LogicalResult ToSignedOp::verify() {
auto input = this->input().getType().cast<EncryptedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedSignedIntegerType>();
if (input.getWidth() != output.getWidth()) {
this->emitOpError(
"should have the width of encrypted input and result equal");
return mlir::failure();
}
return mlir::success();
}
mlir::LogicalResult ToUnsignedOp::verify() {
auto input = this->input().getType().cast<EncryptedSignedIntegerType>();
auto output = this->getResult().getType().cast<EncryptedIntegerType>();
if (input.getWidth() != output.getWidth()) {
this->emitOpError(
"should have the width of encrypted input and result equal");
return mlir::failure();
}
return mlir::success();
}
::mlir::LogicalResult ApplyLookupTableEintOp::verify() {

View File

@@ -110,32 +110,38 @@ LogicalResult verifyTensorBinaryEintInt(mlir::Operation *op) {
op->emitOpError() << "should have exactly 2 operands";
return mlir::failure();
}
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
if (op0Ty == nullptr || op1Ty == nullptr) {
op->emitOpError() << "should have both operands as tensor";
return mlir::failure();
}
auto el0Ty =
op0Ty.getElementType()
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
if (el0Ty == nullptr) {
op->emitOpError() << "should have a !FHE.eint as the element type of the "
"tensor of operand #0";
op->emitOpError()
<< "should have !FHE.eint or !FHE.esint as the element type of the "
"tensor of operand #0";
return mlir::failure();
}
auto el1Ty = op1Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
if (el1Ty == nullptr) {
op->emitOpError() << "should have an integer as the element type of the "
"tensor of operand #1";
return mlir::failure();
}
if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
op->emitOpError()
<< "should have the width of integer values less or equals "
"than the width of encrypted values + 1";
return mlir::failure();
}
return mlir::success();
}
@@ -144,32 +150,38 @@ LogicalResult verifyTensorBinaryIntEint(mlir::Operation *op) {
op->emitOpError() << "should have exactly 2 operands";
return mlir::failure();
}
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
if (op0Ty == nullptr || op1Ty == nullptr) {
op->emitOpError() << "should have both operands as tensor";
return mlir::failure();
}
auto el0Ty = op0Ty.getElementType().dyn_cast_or_null<mlir::IntegerType>();
if (el0Ty == nullptr) {
op->emitOpError() << "should have an integer as the element type of the "
"tensor of operand #0";
return mlir::failure();
}
auto el1Ty =
op1Ty.getElementType()
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
if (el1Ty == nullptr) {
op->emitOpError() << "should have a !FHE.eint as the element type of the "
"tensor of operand #1";
op->emitOpError()
<< "should have !FHE.eint or !FHE.esint as the element type of the "
"tensor of operand #1";
return mlir::failure();
}
if (el1Ty.getWidth() > el0Ty.getWidth() + 1) {
op->emitOpError()
<< "should have the width of integer values less or equals "
"than the width of encrypted values + 1";
return mlir::failure();
}
return mlir::success();
}
@@ -178,34 +190,50 @@ LogicalResult verifyTensorBinaryEint(mlir::Operation *op) {
op->emitOpError() << "should have exactly 2 operands";
return mlir::failure();
}
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
auto op1Ty = op->getOperand(1).getType().dyn_cast_or_null<mlir::TensorType>();
if (op0Ty == nullptr || op1Ty == nullptr) {
op->emitOpError() << "should have both operands as tensor";
return mlir::failure();
}
auto el0Ty =
op0Ty.getElementType()
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
if (el0Ty == nullptr) {
op->emitOpError() << "should have a !FHE.eint as the element type of the "
"tensor of operand #0";
op->emitOpError()
<< "should have !FHE.eint or !FHE.esint as the element type of the "
"tensor of operand #0";
return mlir::failure();
}
auto el1Ty =
op1Ty.getElementType()
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
if (el1Ty == nullptr) {
op->emitOpError() << "should have a !FHE.eint as the element type of the "
"tensor of operand #1";
op->emitOpError()
<< "should have !FHE.eint or !FHE.esint as the element type of the "
"tensor of operand #1";
return mlir::failure();
}
if (el1Ty.getWidth() != el0Ty.getWidth()) {
if (el0Ty.isSigned() != el1Ty.isSigned()) {
op->emitOpError()
<< "should have the signedness of encrypted arguments equal";
return mlir::failure();
}
unsigned el0BitWidth = el0Ty.getWidth();
unsigned el1BitWidth = el1Ty.getWidth();
if (el1BitWidth != el0BitWidth) {
op->emitOpError() << "should have the width of encrypted equals"
", got "
<< el1Ty.getWidth() << " expect " << el0Ty.getWidth();
<< el1BitWidth << " expect " << el0BitWidth;
return mlir::failure();
}
return mlir::success();
}
@@ -214,19 +242,23 @@ LogicalResult verifyTensorUnaryEint(mlir::Operation *op) {
op->emitOpError() << "should have exactly 1 operands";
return mlir::failure();
}
auto op0Ty = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
if (op0Ty == nullptr) {
op->emitOpError() << "should have operand as tensor";
return mlir::failure();
}
auto el0Ty =
op0Ty.getElementType()
.dyn_cast_or_null<mlir::concretelang::FHE::EncryptedIntegerType>();
.dyn_cast_or_null<mlir::concretelang::FHE::FheIntegerInterface>();
if (el0Ty == nullptr) {
op->emitOpError() << "should have a !FHE.eint as the element type of the "
"tensor operand";
op->emitOpError()
<< "should have !FHE.eint or !FHE.esint as the element type of the "
"tensor operand";
return mlir::failure();
}
return mlir::success();
}
@@ -377,14 +409,14 @@ mlir::LogicalResult ApplyMappedLookupTableEintOp::verify() {
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<FHE::EncryptedIntegerType>();
.dyn_cast<FHE::FheIntegerInterface>();
auto rhsEltType = this->rhs()
.getType()
.cast<mlir::TensorType>()
.getElementType()
.cast<mlir::IntegerType>();
auto resultType =
this->getResult().getType().cast<FHE::EncryptedIntegerType>();
this->getResult().getType().dyn_cast<FHE::FheIntegerInterface>();
if (!mlir::concretelang::FHE::
verifyEncryptedIntegerAndIntegerInputsConsistency(
*this->getOperation(), lhsEltType, rhsEltType)) {
@@ -430,16 +462,15 @@ mlir::LogicalResult SumOp::verify() {
mlir::Value output = this->getResult();
auto inputType = input.getType().dyn_cast<mlir::TensorType>();
mlir::Type outputType = output.getType();
Type outputType = output.getType();
FHE::EncryptedIntegerType inputElementType =
inputType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
FHE::EncryptedIntegerType outputElementType =
!outputType.isa<mlir::TensorType>()
? outputType.dyn_cast<FHE::EncryptedIntegerType>()
: outputType.dyn_cast<mlir::TensorType>()
.getElementType()
.dyn_cast<FHE::EncryptedIntegerType>();
auto inputElementType =
inputType.getElementType().dyn_cast<FHE::FheIntegerInterface>();
auto outputElementType = !outputType.isa<mlir::TensorType>()
? outputType.dyn_cast<FHE::FheIntegerInterface>()
: outputType.dyn_cast<mlir::TensorType>()
.getElementType()
.dyn_cast<FHE::FheIntegerInterface>();
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
*this->getOperation(), inputElementType, outputElementType)) {
@@ -517,7 +548,7 @@ mlir::LogicalResult ConcatOp::verify() {
auto outVectorType = out.getType().dyn_cast<mlir::TensorType>();
auto outElementType =
outVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
outVectorType.getElementType().dyn_cast<FHE::FheIntegerInterface>();
llvm::ArrayRef<int64_t> outShape = outVectorType.getShape();
size_t outDims = outShape.size();
@@ -533,7 +564,7 @@ mlir::LogicalResult ConcatOp::verify() {
for (mlir::Value in : this->ins()) {
auto inVectorType = in.getType().dyn_cast<mlir::TensorType>();
auto inElementType =
inVectorType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
inVectorType.getElementType().dyn_cast<FHE::FheIntegerInterface>();
if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
*this->getOperation(), inElementType, outElementType)) {
return ::mlir::failure();
@@ -827,9 +858,11 @@ mlir::LogicalResult Conv2dOp::verify() {
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
auto p = inputTy.getElementType()
.cast<mlir::concretelang::FHE::EncryptedIntegerType>()
.getWidth();
Type inputElTy = inputTy.getElementType();
auto p = inputElTy.isa<FHE::EncryptedIntegerType>()
? inputElTy.cast<FHE::EncryptedIntegerType>().getWidth()
: inputElTy.cast<FHE::EncryptedSignedIntegerType>().getWidth();
auto weightElementTyWidth =
weightTy.getElementType().cast<mlir::IntegerType>().getWidth();
if (weightElementTyWidth != p + 1) {
@@ -1068,6 +1101,62 @@ mlir::LogicalResult TransposeOp::verify() {
return mlir::success();
}
mlir::LogicalResult ToSignedOp::verify() {
auto inputType = this->input().getType().cast<mlir::ShapedType>();
auto outputType = this->getResult().getType().cast<mlir::ShapedType>();
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
if (inputShape != outputShape) {
this->emitOpError()
<< "input and output tensors should have the same shape";
return mlir::failure();
}
auto inputElementType =
inputType.getElementType().cast<FHE::EncryptedIntegerType>();
auto outputElementType =
outputType.getElementType().cast<FHE::EncryptedSignedIntegerType>();
if (inputElementType.getWidth() != outputElementType.getWidth()) {
this->emitOpError()
<< "input and output tensors should have the same width";
return mlir::failure();
}
return mlir::success();
}
mlir::LogicalResult ToUnsignedOp::verify() {
mlir::ShapedType inputType =
this->input().getType().dyn_cast_or_null<mlir::ShapedType>();
mlir::ShapedType outputType =
this->getResult().getType().dyn_cast_or_null<mlir::ShapedType>();
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
llvm::ArrayRef<int64_t> outputShape = outputType.getShape();
if (inputShape != outputShape) {
this->emitOpError()
<< "input and output tensors should have the same shape";
return mlir::failure();
}
auto inputElementType =
inputType.getElementType().cast<FHE::EncryptedSignedIntegerType>();
auto outputElementType =
outputType.getElementType().cast<FHE::EncryptedIntegerType>();
if (inputElementType.getWidth() != outputElementType.getWidth()) {
this->emitOpError()
<< "input and output tensors should have the same width";
return mlir::failure();
}
return mlir::success();
}
/// Avoid addition with constant tensor of 0s
OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);

View File

@@ -0,0 +1,18 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.eint<2>>) outs(%0 : tensor<2x3x4x!FHE.esint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !FHE.eint<2>, %arg2: !FHE.esint<2>):
// CHECK-NEXT: %2 = "FHE.to_signed"(%arg1) : (!FHE.eint<2>) -> !FHE.esint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.esint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.esint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x!FHE.eint<2>>) -> tensor<2x3x4x!FHE.esint<2>> {
%1 = "FHELinalg.to_signed"(%arg0): (tensor<2x3x4x!FHE.eint<2>>) -> (tensor<2x3x4x!FHE.esint<2>>)
return %1: tensor<2x3x4x!FHE.esint<2>>
}

View File

@@ -0,0 +1,18 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
// CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-NEXT: module {
// CHECK-NEXT: func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
// CHECK-NEXT: %0 = bufferization.alloc_tensor() : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<2x3x4x!FHE.esint<2>>) outs(%0 : tensor<2x3x4x!FHE.eint<2>>) {
// CHECK-NEXT: ^bb0(%arg1: !FHE.esint<2>, %arg2: !FHE.eint<2>):
// CHECK-NEXT: %2 = "FHE.to_unsigned"(%arg1) : (!FHE.esint<2>) -> !FHE.eint<2>
// CHECK-NEXT: linalg.yield %2 : !FHE.eint<2>
// CHECK-NEXT: } -> tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: return %1 : tensor<2x3x4x!FHE.eint<2>>
// CHECK-NEXT: }
// CHECK-NEXT: }
func.func @main(%arg0: tensor<2x3x4x!FHE.esint<2>>) -> tensor<2x3x4x!FHE.eint<2>> {
%1 = "FHELinalg.to_unsigned"(%arg0): (tensor<2x3x4x!FHE.esint<2>>) -> (tensor<2x3x4x!FHE.eint<2>>)
return %1: tensor<2x3x4x!FHE.eint<2>>
}

View File

@@ -0,0 +1,31 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs equal
func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.add_eint' op should have the signedness of encrypted inputs equal
func.func @bad_inputs_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.esint<2>) -> !FHE.eint<2> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.add_eint' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.esint<2> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -0,0 +1,26 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of plain input equal to width of encrypted input + 1
func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i3
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> {
%0 = arith.constant 1 : i3
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -1,6 +1,13 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: FHE.eint didn't support precision equals to 0
// CHECK-LABEL: FHE.eint doesn't support precision of 0
func.func @test(%arg0: !FHE.eint<0>) {
return
}
// -----
// CHECK-LABEL: FHE.esint doesn't support precision of 0
func.func @test_signed(%arg0: !FHE.esint<0>) {
return
}

View File

@@ -0,0 +1,26 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of plain input equal to width of encrypted input + 1
func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i3
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> {
%0 = arith.constant 1 : i3
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -0,0 +1,15 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.neg_eint' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.neg_eint' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -1,7 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs equals
func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}

View File

@@ -1,7 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint' op should have the width of encrypted inputs and result equals
func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.add_eint_int' op should have the width of encrypted inputs and result equals
func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i2
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.eint<2>, i2) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of plain input equals to width of encrypted input + 1
func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i4) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.mul_eint_int' op should have the width of encrypted inputs and result equals
func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i2
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.eint<2>, i2) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,7 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.neg_eint' op should have the width of encrypted inputs and result equals
func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.neg_eint"(%arg0): (!FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of plain input equals to width of encrypted input + 1
func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.sub_int_eint"(%0, %arg0): (i4, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}

View File

@@ -1,8 +0,0 @@
// RUN: not concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of encrypted inputs and result equals
func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i2
%1 = "FHE.sub_int_eint"(%0, %arg0): (i2, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}

View File

@@ -9,6 +9,15 @@ func.func @zero() -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK: func.func @zero_signed() -> !FHE.esint<2>
func.func @zero_signed() -> !FHE.esint<2> {
// CHECK-NEXT: %[[RET:.*]] = "FHE.zero"() : () -> !FHE.esint<2>
// CHECK-NEXT: return %[[RET]] : !FHE.esint<2>
%1 = "FHE.zero"() : () -> !FHE.esint<2>
return %1: !FHE.esint<2>
}
// CHECK: func.func @zero_1D() -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<2>>
@@ -18,6 +27,15 @@ func.func @zero_1D() -> tensor<4x!FHE.eint<2>> {
return %0 : tensor<4x!FHE.eint<2>>
}
// CHECK: func.func @zero_1D_signed() -> tensor<4x!FHE.esint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.esint<2>>
// CHECK-NEXT: }
func.func @zero_1D_signed() -> tensor<4x!FHE.esint<2>> {
%0 = "FHE.zero_tensor"() : () -> tensor<4x!FHE.esint<2>>
return %0 : tensor<4x!FHE.esint<2>>
}
// CHECK: func.func @zero_2D() -> tensor<4x9x!FHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.eint<2>>
@@ -27,6 +45,15 @@ func.func @zero_2D() -> tensor<4x9x!FHE.eint<2>> {
return %0 : tensor<4x9x!FHE.eint<2>>
}
// CHECK: func.func @zero_2D_signed() -> tensor<4x9x!FHE.esint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.esint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<4x9x!FHE.esint<2>>
// CHECK-NEXT: }
func.func @zero_2D_signed() -> tensor<4x9x!FHE.esint<2>> {
%0 = "FHE.zero_tensor"() : () -> tensor<4x9x!FHE.esint<2>>
return %0 : tensor<4x9x!FHE.esint<2>>
}
// CHECK-LABEL: func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
@@ -38,6 +65,35 @@ func.func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @add_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
func.func @add_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V2]] : !FHE.esint<2>
%0 = arith.constant 1 : i3
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-NEXT: return %[[V1]] : !FHE.eint<2>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @add_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2>
func.func @add_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.esint<2>, !FHE.esint<2>) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V1]] : !FHE.esint<2>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
@@ -49,6 +105,17 @@ func.func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @sub_int_eint_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
func.func @sub_int_eint_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "FHE.sub_int_eint"(%[[V1]], %arg0) : (i3, !FHE.esint<2>) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V2]] : !FHE.esint<2>
%0 = arith.constant 1 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.esint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
@@ -60,6 +127,17 @@ func.func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @sub_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
func.func @sub_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V2]] : !FHE.esint<2>
%0 = arith.constant 1 : i3
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
@@ -69,6 +147,15 @@ func.func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @sub_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2>
func.func @sub_eint_signed(%arg0: !FHE.esint<2>, %arg1: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.esint<2>, !FHE.esint<2>) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V1]] : !FHE.esint<2>
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.esint<2>, !FHE.esint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func.func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.neg_eint"(%arg0) : (!FHE.eint<2>) -> !FHE.eint<2>
@@ -89,12 +176,32 @@ func.func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// CHECK-LABEL: func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
func.func @add_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
// CHECK-LABEL: func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2>
func.func @mul_eint_int_signed(%arg0: !FHE.esint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.esint<2>, i3) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V2]] : !FHE.esint<2>
%0 = arith.constant 1 : i3
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.esint<2>, i3) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @to_signed(%arg0: !FHE.eint<2>) -> !FHE.esint<2>
func.func @to_signed(%arg0: !FHE.eint<2>) -> !FHE.esint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.to_signed"(%arg0) : (!FHE.eint<2>) -> !FHE.esint<2>
// CHECK-NEXT: return %[[V1]] : !FHE.esint<2>
%1 = "FHE.to_signed"(%arg0): (!FHE.eint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}
// CHECK-LABEL: func.func @to_unsigned(%arg0: !FHE.esint<2>) -> !FHE.eint<2>
func.func @to_unsigned(%arg0: !FHE.esint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.to_unsigned"(%arg0) : (!FHE.esint<2>) -> !FHE.eint<2>
// CHECK-NEXT: return %[[V1]] : !FHE.eint<2>
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
%1 = "FHE.to_unsigned"(%arg0): (!FHE.esint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}

View File

@@ -0,0 +1,31 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.sub_eint' op should have the width of encrypted inputs equal
func.func @bad_inputs_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<3>) -> !FHE.eint<2> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.sub_eint' op should have the signedness of encrypted inputs equal
func.func @bad_inputs_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.esint<2>) -> !FHE.eint<2> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.esint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.sub_eint' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<3> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.sub_eint' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.esint<2> {
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -0,0 +1,26 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of plain input equal to width of encrypted input + 1
func.func @bad_clear_width(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 1 : i4
%1 = "FHE.sub_int_eint"(%0, %arg0): (i4, !FHE.eint<2>) -> (!FHE.eint<2>)
return %1: !FHE.eint<2>
}
// -----
// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the width of encrypted inputs and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.eint<3> {
%0 = arith.constant 1 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.eint<3>)
return %1: !FHE.eint<3>
}
// -----
// CHECK-LABEL: error: 'FHE.sub_int_eint' op should have the signedness of encrypted inputs and result equal
func.func @bad_result_signedness(%arg0: !FHE.eint<2>) -> !FHE.esint<2> {
%0 = arith.constant 1 : i3
%1 = "FHE.sub_int_eint"(%0, %arg0): (i3, !FHE.eint<2>) -> (!FHE.esint<2>)
return %1: !FHE.esint<2>
}

View File

@@ -0,0 +1,7 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.to_signed' op should have the width of encrypted input and result equal
func.func @bad_result_width(%arg0: !FHE.eint<2>) -> !FHE.esint<3> {
%1 = "FHE.to_signed"(%arg0): (!FHE.eint<2>) -> !FHE.esint<3>
return %1: !FHE.esint<3>
}

View File

@@ -0,0 +1,7 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHE.to_unsigned' op should have the width of encrypted input and result equal
func.func @bad_result_width(%arg0: !FHE.esint<2>) -> !FHE.eint<3> {
%1 = "FHE.to_unsigned"(%arg0): (!FHE.esint<2>) -> !FHE.eint<3>
return %1: !FHE.eint<3>
}

View File

@@ -19,7 +19,7 @@ func.func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> {
// -----
func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<7>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<6>>
return %0 : tensor<7x!FHE.eint<6>>
}
@@ -27,7 +27,7 @@ func.func @main(%x: tensor<4x!FHE.eint<7>>, %y: tensor<3x!FHE.eint<7>>) -> tenso
// -----
func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<6>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}
@@ -35,7 +35,7 @@ func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<6>>) -> tenso
// -----
func.func @main(%x: tensor<4x!FHE.eint<6>>, %y: tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>> {
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equals}}
// expected-error @+1 {{'FHELinalg.concat' op should have the width of encrypted inputs and result equal}}
%0 = "FHELinalg.concat"(%x, %y) : (tensor<4x!FHE.eint<6>>, tensor<3x!FHE.eint<7>>) -> tensor<7x!FHE.eint<7>>
return %0 : tensor<7x!FHE.eint<7>>
}

View File

@@ -47,7 +47,7 @@ func.func @dot_incompatible_return(
%arg0: tensor<4x!FHE.eint<2>>,
%arg1: tensor<4xi3>) -> !FHE.eint<3>
{
// expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of encrypted inputs and result equals}}
// expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of encrypted inputs and result equal}}
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<3>
@@ -61,7 +61,7 @@ func.func @dot_incompatible_int(
%arg0: tensor<4x!FHE.eint<2>>,
%arg1: tensor<4xi4>) -> !FHE.eint<2>
{
// expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of plain input equals to width of encrypted input + 1}}
// expected-error @+1 {{'FHELinalg.dot_eint_int' op should have the width of plain input equal to width of encrypted input + 1}}
%ret = "FHELinalg.dot_eint_int"(%arg0, %arg1) :
(tensor<4x!FHE.eint<2>>, tensor<4xi4>) -> !FHE.eint<2>

View File

@@ -3,7 +3,7 @@
// -----
func.func @sum_invalid_bitwidth(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<6> {
// expected-error @+1 {{'FHELinalg.sum' op should have the width of encrypted inputs and result equals}}
// expected-error @+1 {{'FHELinalg.sum' op should have the width of encrypted inputs and result equal}}
%1 = "FHELinalg.sum"(%arg0): (tensor<4x!FHE.eint<7>>) -> !FHE.eint<6>
return %1 : !FHE.eint<6>
}

View File

@@ -0,0 +1,15 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHELinalg.to_signed' op input and output tensors should have the same width
func.func @bad_result_width(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>> {
%1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<3>>
return %1: tensor<3x2x!FHE.esint<3>>
}
// -----
// CHECK-LABEL: error: 'FHELinalg.to_signed' op input and output tensors should have the same shape
func.func @bad_result_shape(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> {
%1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>>
return %1: tensor<3x!FHE.esint<2>>
}

View File

@@ -0,0 +1,23 @@
// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// -----
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_signed"(%[[a0]]) : (tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<3x!FHE.esint<2>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>> {
%1 = "FHELinalg.to_signed"(%arg0): (tensor<3x!FHE.eint<2>>) -> tensor<3x!FHE.esint<2>>
return %1 : tensor<3x!FHE.esint<2>>
}
// -----
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_signed"(%[[a0]]) : (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.esint<2>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>> {
%1 = "FHELinalg.to_signed"(%arg0): (tensor<3x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.esint<2>>
return %1 : tensor<3x2x!FHE.esint<2>>
}

View File

@@ -0,0 +1,15 @@
// RUN: not concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// CHECK-LABEL: error: 'FHELinalg.to_unsigned' op input and output tensors should have the same width
func.func @bad_result_width(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>> {
%1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<3>>
return %1: tensor<3x2x!FHE.eint<3>>
}
// -----
// CHECK-LABEL: error: 'FHELinalg.to_unsigned' op input and output tensors should have the same shape
func.func @bad_result_shape(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> {
%1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>>
return %1: tensor<3x!FHE.eint<2>>
}

View File

@@ -0,0 +1,23 @@
// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s
// -----
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_unsigned"(%[[a0]]) : (tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<3x!FHE.eint<2>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>> {
%1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x!FHE.esint<2>>) -> tensor<3x!FHE.eint<2>>
return %1 : tensor<3x!FHE.eint<2>>
}
// -----
// CHECK: func.func @main(%[[a0:.*]]: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> {
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.to_unsigned"(%[[a0]]) : (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>>
// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.eint<2>>
// CHECK-NEXT: }
func.func @main(%arg0: tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>> {
%1 = "FHELinalg.to_unsigned"(%arg0): (tensor<3x2x!FHE.esint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}