mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 12:44:57 -05:00
feat: implement advanced sum operation
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
@@ -517,12 +517,21 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
|
||||
}
|
||||
|
||||
def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
|
||||
let summary = "Returns the sum of all elements of a tensor of encrypted integers.";
|
||||
let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes.";
|
||||
|
||||
let description = [{
|
||||
Performs a sum to a tensor of encrypted integers.
|
||||
Attributes:
|
||||
|
||||
- keep_dims: boolean = false
|
||||
whether to keep the rank of the tensor after the sum operation
|
||||
if true, reduced axes will have the size of 1
|
||||
|
||||
- axes: I64ArrayAttr = []
|
||||
list of dimension to perform the sum along
|
||||
think of it as the dimensions to reduce (see examples below to get an intuition)
|
||||
|
||||
Examples:
|
||||
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0`
|
||||
"FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
|
||||
@@ -532,13 +541,64 @@ def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
|
||||
// ( [7,8,9] )
|
||||
//
|
||||
```
|
||||
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0` along columns
|
||||
"FHELinalg.sum"(%a0) { axes = [0] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<4>>
|
||||
//
|
||||
// ( [1,2] )
|
||||
// sum ( [3,4] ) = [9, 12]
|
||||
// ( [5,6] )
|
||||
//
|
||||
```
|
||||
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0` along columns while preserving dimensions
|
||||
"FHELinalg.sum"(%a0) { axes = [0], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<1x2x!FHE.eint<4>>
|
||||
//
|
||||
// ( [1,2] )
|
||||
// sum ( [3,4] ) = [[9, 12]]
|
||||
// ( [5,6] )
|
||||
//
|
||||
```
|
||||
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0` along rows
|
||||
"FHELinalg.sum"(%a0) { axes = [1] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>>
|
||||
//
|
||||
// ( [1,2] )
|
||||
// sum ( [3,4] ) = [3, 7, 11]
|
||||
// ( [5,6] )
|
||||
//
|
||||
```
|
||||
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0` along rows while preserving dimensions
|
||||
"FHELinalg.sum"(%a0) { axes = [1], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x1x!FHE.eint<4>>
|
||||
//
|
||||
// ( [1,2] ) [3]
|
||||
// sum ( [3,4] ) = [7]
|
||||
// ( [5,6] ) [11]
|
||||
//
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor,
|
||||
DefaultValuedAttr<I64ArrayAttr, "{}">:$axes,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$keep_dims
|
||||
);
|
||||
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
let results = (outs
|
||||
TypeConstraint<Or<[
|
||||
EncryptedIntegerType.predicate,
|
||||
And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>
|
||||
]>>:$out
|
||||
);
|
||||
|
||||
let verifier = [{
|
||||
return mlir::concretelang::FHELinalg::verifySum(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user