feat: implement advanced sum operation

This commit is contained in:
Umut
2022-02-08 13:14:42 +03:00
parent dddad849c7
commit a1818a3fd9
15 changed files with 2922 additions and 223 deletions

View File

@@ -8,6 +8,7 @@
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include <mlir/IR/BuiltinOps.h>

View File

@@ -517,12 +517,21 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
}
def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
let summary = "Returns the sum of all elements of a tensor of encrypted integers.";
let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes.";
let description = [{
Performs a sum to a tensor of encrypted integers.
Attributes:
- keep_dims: boolean = false
whether to keep the rank of the tensor after the sum operation
if true, reduced axes will have the size of 1
- axes: I64ArrayAttr = []
list of dimension to perform the sum along
think of it as the dimensions to reduce (see examples below to get an intuition)
Examples:
```mlir
// Returns the sum of all elements of `%a0`
"FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
@@ -532,13 +541,64 @@ def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
// ( [7,8,9] )
//
```
```mlir
// Returns the sum of all elements of `%a0` along columns
"FHELinalg.sum"(%a0) { axes = [0] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<4>>
//
// ( [1,2] )
// sum ( [3,4] ) = [9, 12]
// ( [5,6] )
//
```
```mlir
// Returns the sum of all elements of `%a0` along columns while preserving dimensions
"FHELinalg.sum"(%a0) { axes = [0], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<1x2x!FHE.eint<4>>
//
// ( [1,2] )
// sum ( [3,4] ) = [[9, 12]]
// ( [5,6] )
//
```
```mlir
// Returns the sum of all elements of `%a0` along rows
"FHELinalg.sum"(%a0) { axes = [1] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>>
//
// ( [1,2] )
// sum ( [3,4] ) = [3, 7, 11]
// ( [5,6] )
//
```
```mlir
// Returns the sum of all elements of `%a0` along rows while preserving dimensions
"FHELinalg.sum"(%a0) { axes = [1], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x1x!FHE.eint<4>>
//
// ( [1,2] ) [3]
// sum ( [3,4] ) = [7]
// ( [5,6] ) [11]
//
```
}];
let arguments = (ins
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor,
DefaultValuedAttr<I64ArrayAttr, "{}">:$axes,
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
);
let results = (outs EncryptedIntegerType:$out);
let results = (outs
TypeConstraint<Or<[
EncryptedIntegerType.predicate,
And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>
]>>:$out
);
let verifier = [{
return mlir::concretelang::FHELinalg::verifySum(*this);
}];
}
#endif