mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: implement basic sum operation
This commit is contained in:
@@ -516,4 +516,29 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$aggregate);
|
||||
}
|
||||
|
||||
def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
|
||||
let summary = "Returns the sum of all elements of a tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Performs a sum to a tensor of encrypted integers.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the sum of all elements of `%a0`
|
||||
"FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
|
||||
//
|
||||
// ( [1,2,3] )
|
||||
// sum ( [4,5,6] ) = 45
|
||||
// ( [7,8,9] )
|
||||
//
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
|
||||
);
|
||||
|
||||
let results = (outs EncryptedIntegerType:$out);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -961,6 +961,128 @@ struct FHELinalgZeroToLinalgGenerate
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of operators
|
||||
// `FHELinalg.zero` to an instance of `linalg.generate` with an
|
||||
// appropriate region yielding a zero value.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// %result = "FHELinalg.sum"(%input) :
|
||||
// tensor<d0xd1x...xdNx!FHE.eint<p>>() -> !FHE.eint<p>
|
||||
//
|
||||
// becomes:
|
||||
//
|
||||
// #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
|
||||
// #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
|
||||
//
|
||||
// %zero = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// %accumulator = tensor.from_elements %zero : tensor<1x!FHE.eint<7>>
|
||||
//
|
||||
// %accumulation = linalg.generic
|
||||
// { indexing_maps = [#map0, #map1], iterator_types = ["reduction",
|
||||
// "reduction", ..., "reduction"] } ins(%input :
|
||||
// tensor<d0xd1x...xdNx!FHE.eint<7>>) outs(%accumulator :
|
||||
// tensor<1x!FHE.eint<7>>)
|
||||
// {
|
||||
// ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
|
||||
// %c = "FHE.add_eint"(%a, %b) :
|
||||
// (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
|
||||
// linalg.yield %c : !FHE.eint<7>
|
||||
// } -> tensor<1x!FHE.eint<7>>
|
||||
//
|
||||
// %index = arith.constant 0 : index
|
||||
// %result = tensor.extract %index : tensor<1x!FHE.eint<7>>
|
||||
//
|
||||
struct SumToLinalgGeneric
|
||||
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::SumOp> {
|
||||
SumToLinalgGeneric(::mlir::MLIRContext *context)
|
||||
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::SumOp>(
|
||||
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(::mlir::concretelang::FHELinalg::SumOp sumOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
namespace arith = mlir::arith;
|
||||
namespace linalg = mlir::linalg;
|
||||
namespace tensor = mlir::tensor;
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
|
||||
mlir::Location location = sumOp.getLoc();
|
||||
|
||||
mlir::Value input = sumOp.getOperand();
|
||||
mlir::Value output = sumOp.getResult();
|
||||
|
||||
auto inputType = input.getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
assert(inputType != nullptr);
|
||||
|
||||
llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
|
||||
size_t inputDimensions = inputShape.size();
|
||||
|
||||
mlir::Value zero =
|
||||
rewriter.create<FHE::ZeroEintOp>(location, output.getType())
|
||||
.getResult();
|
||||
|
||||
for (size_t i = 0; i < inputDimensions; i++) {
|
||||
if (inputShape[i] == 0) {
|
||||
rewriter.replaceOp(sumOp, {zero});
|
||||
return mlir::success();
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Value accumulator =
|
||||
rewriter.create<tensor::FromElementsOp>(location, zero).getResult();
|
||||
|
||||
auto ins = llvm::SmallVector<mlir::Value, 1>{input};
|
||||
auto outs = llvm::SmallVector<mlir::Value, 1>{accumulator};
|
||||
|
||||
mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimIdentityMap(
|
||||
inputDimensions, this->getContext());
|
||||
|
||||
mlir::AffineMap outputMap = mlir::AffineMap::get(
|
||||
inputDimensions, 0, {rewriter.getAffineConstantExpr(0)},
|
||||
rewriter.getContext());
|
||||
|
||||
auto maps = llvm::SmallVector<mlir::AffineMap, 2>{inputMap, outputMap};
|
||||
|
||||
auto iteratorTypes = llvm::SmallVector<llvm::StringRef, 3>{};
|
||||
for (size_t i = 0; i < inputDimensions; i++) {
|
||||
iteratorTypes.push_back("reduction");
|
||||
}
|
||||
|
||||
auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
mlir::Value lhs = blockArgs[0];
|
||||
mlir::Value rhs = blockArgs[1];
|
||||
|
||||
mlir::Value addition =
|
||||
nestedBuilder.create<FHE::AddEintOp>(location, lhs, rhs).getResult();
|
||||
|
||||
nestedBuilder.create<linalg::YieldOp>(location, addition);
|
||||
};
|
||||
|
||||
auto resultTypes = llvm::SmallVector<mlir::Type, 1>{accumulator.getType()};
|
||||
mlir::Value accumulation =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
|
||||
iteratorTypes, regionBuilder)
|
||||
.getResult(0);
|
||||
|
||||
mlir::Value index =
|
||||
rewriter.create<arith::ConstantIndexOp>(location, 0).getResult();
|
||||
auto indices = llvm::SmallVector<mlir::Value, 1>{index};
|
||||
|
||||
mlir::Value result =
|
||||
rewriter.create<tensor::ExtractOp>(location, accumulation, indices)
|
||||
.getResult();
|
||||
rewriter.replaceOp(sumOp, {result});
|
||||
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
namespace {
|
||||
struct FHETensorOpsToLinalg
|
||||
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
|
||||
@@ -1020,6 +1142,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
|
||||
patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
|
||||
&getContext());
|
||||
patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
|
||||
patterns.insert<SumToLinalgGeneric>(&getContext());
|
||||
|
||||
if (mlir::applyPartialConversion(function, target, std::move(patterns))
|
||||
.failed())
|
||||
|
||||
@@ -837,6 +837,29 @@ static llvm::APInt getSqMANP(
|
||||
return operandMANPs[0]->getValue().getMANP().getValue();
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SumOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
auto type = op->getOperand(0).getType().dyn_cast_or_null<mlir::TensorType>();
|
||||
|
||||
uint64_t numberOfElements = type.getNumElements();
|
||||
if (numberOfElements == 0) {
|
||||
return llvm::APInt{1, 1, false};
|
||||
}
|
||||
|
||||
assert(operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
unsigned int multiplierBits = ceilLog2(numberOfElements + 1);
|
||||
auto multiplier = llvm::APInt{multiplierBits, numberOfElements, false};
|
||||
|
||||
return APIntWidthExtendUMul(multiplier, operandMANP);
|
||||
}
|
||||
|
||||
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
|
||||
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
|
||||
@@ -909,6 +932,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = llvm::APInt{1, 1, false};
|
||||
} else if (auto sumOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(sumOp, operands);
|
||||
}
|
||||
// Tensor Operators
|
||||
// ExtractOp
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
|
||||
|
||||
// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)>
|
||||
|
||||
// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
|
||||
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)>
|
||||
|
||||
// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
|
||||
// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
|
||||
// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
|
||||
// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
|
||||
// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
|
||||
|
||||
// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
@@ -227,7 +227,7 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// ceil(sqrt(65)) = 9
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -310,7 +310,7 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// ceil(sqrt(65)) = 9
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -390,3 +390,25 @@ func @zero() -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sum() -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
|
||||
// CHECK: "FHELinalg.sum"(%0) {MANP = 2 : ui{{[0-9]+}}} : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
%1 = "FHELinalg.sum"(%0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
|
||||
%2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
|
||||
// CHECK: "FHELinalg.sum"(%2) {MANP = 3 : ui{{[0-9]+}}} : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
%3 = "FHELinalg.sum"(%2) : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
|
||||
%4 = "FHELinalg.zero"() : () -> tensor<9x!FHE.eint<7>>
|
||||
// CHECK: "FHELinalg.sum"(%4) {MANP = 3 : ui{{[0-9]+}}} : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
%5 = "FHELinalg.sum"(%4) : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
|
||||
%6 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
|
||||
// CHECK: "FHELinalg.sum"(%6) {MANP = 4 : ui{{[0-9]+}}} : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
%7 = "FHELinalg.sum"(%6) : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
|
||||
return %7 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
@@ -367,3 +367,43 @@ func @zero_2D() -> tensor<4x9x!FHE.eint<2>> {
|
||||
%0 = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>>
|
||||
return %0 : tensor<4x9x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.sum
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
|
||||
// CHECK-NEXT: }
|
||||
func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
@@ -1579,6 +1579,152 @@ func @main() -> tensor<2x2x4x!FHE.eint<6>> {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg sum /////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sum_empty) {
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::Expected;
|
||||
|
||||
using mlir::concretelang::IntLambdaArgument;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::TensorLambdaArgument;
|
||||
|
||||
JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
func @main(%x: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%x) : (tensor<0x!FHE.eint<7>>) -> (!FHE.eint<7>)
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
)XXX");
|
||||
|
||||
const uint8_t expected = 0;
|
||||
|
||||
ArrayRef<uint8_t> xRef(nullptr, (size_t)0);
|
||||
TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {0});
|
||||
|
||||
Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
|
||||
ASSERT_EXPECTED_SUCCESS(result);
|
||||
|
||||
ASSERT_EQ(*result, expected);
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sum_1D) {
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::Expected;
|
||||
|
||||
using mlir::concretelang::IntLambdaArgument;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::TensorLambdaArgument;
|
||||
|
||||
JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%x) : (tensor<4x!FHE.eint<7>>) -> (!FHE.eint<7>)
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
)XXX");
|
||||
|
||||
const uint8_t x[4]{0, 1, 2, 3};
|
||||
const uint8_t expected = 6;
|
||||
|
||||
ArrayRef<uint8_t> xRef((const uint8_t *)x, 4);
|
||||
TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {4});
|
||||
|
||||
Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
|
||||
ASSERT_EXPECTED_SUCCESS(result);
|
||||
|
||||
ASSERT_EQ(*result, expected);
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sum_2D) {
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::Expected;
|
||||
|
||||
using mlir::concretelang::IntLambdaArgument;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::TensorLambdaArgument;
|
||||
|
||||
JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%x) : (tensor<3x4x!FHE.eint<7>>) -> (!FHE.eint<7>)
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
)XXX");
|
||||
|
||||
const uint8_t x[3][4]{
|
||||
{0, 1, 2, 3},
|
||||
{4, 5, 6, 7},
|
||||
{8, 9, 0, 1},
|
||||
};
|
||||
const uint8_t expected = 46;
|
||||
|
||||
ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4);
|
||||
TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4});
|
||||
|
||||
Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
|
||||
ASSERT_EXPECTED_SUCCESS(result);
|
||||
|
||||
ASSERT_EQ(*result, expected);
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sum_3D) {
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::Expected;
|
||||
|
||||
using mlir::concretelang::IntLambdaArgument;
|
||||
using mlir::concretelang::JitCompilerEngine;
|
||||
using mlir::concretelang::TensorLambdaArgument;
|
||||
|
||||
JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
|
||||
|
||||
func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
|
||||
%0 = "FHELinalg.sum"(%x) : (tensor<3x4x2x!FHE.eint<7>>) -> (!FHE.eint<7>)
|
||||
return %0 : !FHE.eint<7>
|
||||
}
|
||||
|
||||
)XXX");
|
||||
|
||||
const uint8_t x[3][4][2]{
|
||||
{
|
||||
{0, 1},
|
||||
{2, 3},
|
||||
{4, 5},
|
||||
{6, 7},
|
||||
},
|
||||
{
|
||||
{8, 9},
|
||||
{0, 1},
|
||||
{2, 3},
|
||||
{4, 5},
|
||||
},
|
||||
{
|
||||
{6, 7},
|
||||
{8, 9},
|
||||
{0, 1},
|
||||
{2, 3},
|
||||
},
|
||||
};
|
||||
const uint8_t expected = 96;
|
||||
|
||||
ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
|
||||
TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4, 2});
|
||||
|
||||
Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
|
||||
ASSERT_EXPECTED_SUCCESS(result);
|
||||
|
||||
ASSERT_EQ(*result, expected);
|
||||
}
|
||||
|
||||
class TiledMatMulParametric
|
||||
: public ::testing::TestWithParam<std::vector<int64_t>> {};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user