From 4203e86998f2c73c9c288567bd5387180ce18db7 Mon Sep 17 00:00:00 2001
From: Umut
Date: Tue, 1 Feb 2022 11:17:43 +0300
Subject: [PATCH] feat: implement basic sum operation

---
 .../Dialect/FHELinalg/IR/FHELinalgOps.td      |  25 +++
 .../TensorOpsToLinalg.cpp                     | 123 +++++++++++++++
 compiler/lib/Dialect/FHE/Analysis/MANP.cpp    |  26 ++++
 .../FHELinalgToLinalg/sum_1d.mlir             |  21 +++
 .../FHELinalgToLinalg/sum_2d.mlir             |  21 +++
 .../FHELinalgToLinalg/sum_3d.mlir             |  21 +++
 .../FHELinalgToLinalg/sum_empty.mlir          |  10 ++
 .../Dialect/FHE/FHE/Analysis/MANP_linalg.mlir |  26 +++-
 .../Dialect/FHELinalg/FHELinalg/ops.mlir      |  40 +++++
 .../unittest/end_to_end_jit_fhelinalg.cc      | 146 ++++++++++++++++++
 10 files changed, 457 insertions(+), 2 deletions(-)
 create mode 100644 compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir
 create mode 100644 compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir
 create mode 100644 compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir
 create mode 100644 compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir

diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
index b8a9d3dd2..73e6b3488 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
@@ -516,4 +516,29 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
   let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$aggregate);
 }
 
+def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
+  let summary = "Returns the sum of all elements of a tensor of encrypted integers.";
+
+  let description = [{
+    Computes the sum of all elements of a tensor of encrypted integers.
+
+    Examples:
+    ```mlir
+    // Returns the sum of all elements of `%a0`
+    "FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
+    //
+    //       ( [1,2,3] )
+    //   sum ( [4,5,6] ) = 45
+    //       ( [7,8,9] )
+    //
+    ```
+  }];
+
+  let arguments = (ins
+    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
+  );
+
+  let results = (outs EncryptedIntegerType:$out);
+}
+
 #endif
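Usage sketch (an editorial aside, not part of the diff): the new op composes with the dialect's existing element-wise operators, for example to reduce an element-wise product into an encrypted dot product. `FHELinalg.mul_eint_int` below is assumed to be available elsewhere in the dialect; the shapes and bit widths are illustrative.

```mlir
// Hypothetical composition: multiply an encrypted vector by clear weights,
// then reduce the products with the new FHELinalg.sum.
func @dot(%x: tensor<4x!FHE.eint<7>>, %w: tensor<4xi8>) -> !FHE.eint<7> {
  %prod = "FHELinalg.mul_eint_int"(%x, %w)
      : (tensor<4x!FHE.eint<7>>, tensor<4xi8>) -> tensor<4x!FHE.eint<7>>
  %sum = "FHELinalg.sum"(%prod) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
  return %sum : !FHE.eint<7>
}
```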

diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 84ced02ad..d1342351d 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -961,6 +961,128 @@ struct FHELinalgZeroToLinalgGenerate
  };
 };
 
+// This rewrite pattern transforms any instance of the `FHELinalg.sum`
+// operator into an instance of `linalg.generic` whose region accumulates
+// all elements of the input into a 1-element tensor, followed by a
+// `tensor.extract` of the accumulated value.
+//
+// Example:
+//
+//   %result = "FHELinalg.sum"(%input) :
+//     (tensor<D0xD1x...xDNx!FHE.eint<7>>) -> !FHE.eint<7>
+//
+// becomes:
+//
+//   #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
+//   #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
+//
+//   %zero = "FHE.zero"() : () -> !FHE.eint<7>
+//   %accumulator = tensor.from_elements %zero : tensor<1x!FHE.eint<7>>
+//
+//   %accumulation = linalg.generic
+//       { indexing_maps = [#map0, #map1],
+//         iterator_types = ["reduction", "reduction", ..., "reduction"] }
+//       ins(%input : tensor<D0xD1x...xDNx!FHE.eint<7>>)
+//       outs(%accumulator : tensor<1x!FHE.eint<7>>)
+//   {
+//     ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
+//       %c = "FHE.add_eint"(%a, %b) :
+//         (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
+//       linalg.yield %c : !FHE.eint<7>
+//   } -> tensor<1x!FHE.eint<7>>
+//
+//   %index = arith.constant 0 : index
+//   %result = tensor.extract %accumulation[%index] : tensor<1x!FHE.eint<7>>
+//
+struct SumToLinalgGeneric
+    : public ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::SumOp> {
+  SumToLinalgGeneric(::mlir::MLIRContext *context)
+      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::SumOp>(
+            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(::mlir::concretelang::FHELinalg::SumOp sumOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    namespace arith = mlir::arith;
+    namespace linalg = mlir::linalg;
+    namespace tensor = mlir::tensor;
+
+    namespace FHE = mlir::concretelang::FHE;
+
+    mlir::Location location = sumOp.getLoc();
+
+    mlir::Value input = sumOp.getOperand();
+    mlir::Value output = sumOp.getResult();
+
+    auto inputType =
+        input.getType().dyn_cast_or_null<mlir::RankedTensorType>();
+    assert(inputType != nullptr);
+
+    llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
+    size_t inputDimensions = inputShape.size();
+
+    mlir::Value zero =
+        rewriter.create<FHE::ZeroEintOp>(location, output.getType())
+            .getResult();
+
+    // A sum over a tensor with a zero-sized dimension has no elements,
+    // so it folds directly to an encrypted zero.
+    for (size_t i = 0; i < inputDimensions; i++) {
+      if (inputShape[i] == 0) {
+        rewriter.replaceOp(sumOp, {zero});
+        return mlir::success();
+      }
+    }
+
+    // linalg.generic requires a shaped output, so the scalar result is
+    // accumulated into a 1-element tensor and extracted afterwards.
+    mlir::Value accumulator =
+        rewriter.create<tensor::FromElementsOp>(location, zero).getResult();
+
+    auto ins = llvm::SmallVector<mlir::Value>{input};
+    auto outs = llvm::SmallVector<mlir::Value>{accumulator};
+
+    mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimIdentityMap(
+        inputDimensions, this->getContext());
+
+    // Every iteration point writes to index 0 of the accumulator.
+    mlir::AffineMap outputMap = mlir::AffineMap::get(
+        inputDimensions, 0, {rewriter.getAffineConstantExpr(0)},
+        rewriter.getContext());
+
+    auto maps = llvm::SmallVector<mlir::AffineMap>{inputMap, outputMap};
+
+    auto iteratorTypes = llvm::SmallVector<llvm::StringRef>{};
+    for (size_t i = 0; i < inputDimensions; i++) {
+      iteratorTypes.push_back("reduction");
+    }
+
+    auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
+                             mlir::Location nestedLoc,
+                             mlir::ValueRange blockArgs) {
+      mlir::Value lhs = blockArgs[0];
+      mlir::Value rhs = blockArgs[1];
+
+      mlir::Value addition =
+          nestedBuilder.create<FHE::AddEintOp>(location, lhs, rhs)
+              .getResult();
+
+      nestedBuilder.create<linalg::YieldOp>(location, addition);
+    };
+
+    auto resultTypes = llvm::SmallVector<mlir::Type>{accumulator.getType()};
+    mlir::Value accumulation =
+        rewriter
+            .create<linalg::GenericOp>(location, resultTypes, ins, outs,
+                                       maps, iteratorTypes, regionBuilder)
+            .getResult(0);
+
+    mlir::Value index =
+        rewriter.create<arith::ConstantIndexOp>(location, 0).getResult();
+    auto indices = llvm::SmallVector<mlir::Value>{index};
+
+    mlir::Value result =
+        rewriter.create<tensor::ExtractOp>(location, accumulation, indices)
+            .getResult();
+    rewriter.replaceOp(sumOp, {result});
+
+    return mlir::success();
+  };
+};
+
 namespace {
 struct FHETensorOpsToLinalg
     : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1020,6 +1142,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
   patterns.insert<FHELinalgApplyMappedLookupTableToLinalgGeneric>(
       &getContext());
   patterns.insert<FHELinalgZeroToLinalgGenerate>(&getContext());
+  patterns.insert<SumToLinalgGeneric>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
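Why the 1-element accumulator: `linalg.generic` needs a shaped output, so the pattern reduces into a `tensor<1x!FHE.eint<p>>` and extracts element 0 afterwards. As a concrete illustration (nothing new, just the rank-2 instance of the `#map0`/`#map1` pair the pattern builds above):

```mlir
// Rank-2 instance of the maps built by SumToLinalgGeneric: every iteration
// point (d0, d1) reads input[d0, d1] and accumulates into accumulator[0],
// which turns the linalg.generic into a full reduction.
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (0)>
```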
diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index b473c4b44..043d9b091 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -837,6 +837,29 @@ static llvm::APInt getSqMANP(
   return operandMANPs[0]->getValue().getMANP().getValue();
 }
 
+static llvm::APInt getSqMANP(
+    mlir::concretelang::FHELinalg::SumOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  auto type =
+      op->getOperand(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
+
+  uint64_t numberOfElements = type.getNumElements();
+  if (numberOfElements == 0) {
+    return llvm::APInt{1, 1, false};
+  }
+
+  assert(operandMANPs.size() == 1 &&
+         operandMANPs[0]->getValue().getMANP().hasValue() &&
+         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
+         "operands");
+  llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();
+
+  // Summing N elements scales the squared MANP of the operand by N.
+  unsigned int multiplierBits = ceilLog2(numberOfElements + 1);
+  auto multiplier = llvm::APInt{multiplierBits, numberOfElements, false};
+
+  return APIntWidthExtendUMul(multiplier, operandMANP);
+}
+
 struct MANPAnalysis
     : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
   using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis;
   MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -909,6 +932,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                  mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(
                  op)) {
     norm2SqEquiv = llvm::APInt{1, 1, false};
+  } else if (auto sumOp =
+                 llvm::dyn_cast<mlir::concretelang::FHELinalg::SumOp>(op)) {
+    norm2SqEquiv = getSqMANP(sumOp, operands);
   }
   // Tensor Operators
   // ExtractOp
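A quick numeric check of the rule implemented above (an illustration, not part of the patch): summing N encrypted elements scales the squared MANP by N, and the `MANP` attribute recorded on the op is the ceiling of the square root:

```latex
\mathrm{MANP}(\mathrm{sum}) =
  \left\lceil \sqrt{N \cdot \mathrm{MANP}^2(\mathrm{operand})} \right\rceil
```

With an operand squared MANP of 1, N = 4 yields 2, N = 5 and N = 9 yield 3, and N = 10 yields 4, which is exactly what the MANP_linalg.mlir test below checks.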
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir
new file mode 100644
index 000000000..2f3d6fc7b
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir
@@ -0,0 +1,21 @@
+// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)>
+
+// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
+// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
+// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
+// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir
new file mode 100644
index 000000000..9c8af2f59
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir
@@ -0,0 +1,21 @@
+// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)>
+
+// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
+// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
+// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
+// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir
new file mode 100644
index 000000000..94ad08c88
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir
@@ -0,0 +1,21 @@
+// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)>
+
+// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) {
+// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>):
+// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7>
+// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7>
+// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>>
+// CHECK-NEXT: return %[[v3]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
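Before the empty-tensor test below, note that the early return in `SumToLinalgGeneric` means a sum over any zero-sized dimension never materializes a `linalg.generic` at all; the whole op folds to a single ciphertext:

```mlir
// Folded lowering for an empty input, as sum_empty.mlir checks below.
%0 = "FHE.zero"() : () -> !FHE.eint<7>
```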
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir
new file mode 100644
index 000000000..5071cb970
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir
@@ -0,0 +1,10 @@
+// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
+
+// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
index e95c8fa62..b9789e4b0 100644
--- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
+++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
@@ -227,7 +227,7 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
   // mul = manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3)^2 = 64
   // manp(add_eint(mul, acc)) = 64 + 1 = 65
   // ceil(sqrt(65)) = 9
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} 
+  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
   %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -310,7 +310,7 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
   // mul = manp(mul_eint_int(eint<2>, i3)) = 1 * (2^3)^2 = 64
   // manp(add_eint(mul, acc)) = 64 + 1 = 65
   // ceil(sqrt(65)) = 9
-  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} 
+  // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
   %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
   return %1 : tensor<3x2x!FHE.eint<2>>
 }
@@ -390,3 +390,25 @@ func @zero() -> tensor<8x!FHE.eint<2>>
   return %0 : tensor<8x!FHE.eint<2>>
 }
+
+// -----
+
+func @sum() -> !FHE.eint<7> {
+  %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>>
+  // CHECK: "FHELinalg.sum"(%0) {MANP = 2 : ui{{[0-9]+}}} : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
+  %1 = "FHELinalg.sum"(%0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
+
+  %2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>>
+  // CHECK: "FHELinalg.sum"(%2) {MANP = 3 : ui{{[0-9]+}}} : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
+  %3 = "FHELinalg.sum"(%2) : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7>
+
+  %4 = "FHELinalg.zero"() : () -> tensor<9x!FHE.eint<7>>
+  // CHECK: "FHELinalg.sum"(%4) {MANP = 3 : ui{{[0-9]+}}} : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
+  %5 = "FHELinalg.sum"(%4) : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7>
+
+  %6 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>>
+  // CHECK: "FHELinalg.sum"(%6) {MANP = 4 : ui{{[0-9]+}}} : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
+  %7 = "FHELinalg.sum"(%6) : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7>
+
+  return %7 : !FHE.eint<7>
+}
diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
index f718d8f1f..740f38224 100644
--- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
+++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir
@@ -367,3 +367,43 @@ func @zero_2D() -> tensor<4x9x!FHE.eint<2>> {
   %0 = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>>
   return %0 : tensor<4x9x!FHE.eint<2>>
 }
+
+/////////////////////////////////////////////////
+// FHELinalg.sum
+/////////////////////////////////////////////////
+
+// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
+
+// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
+
+// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
+
+// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
index d8aa38f1a..32cdc119c 100644
--- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
+++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc
@@ -1579,6 +1579,152 @@ func @main() -> tensor<2x2x4x!FHE.eint<6>> {
   }
 }
 
+///////////////////////////////////////////////////////////////////////////////
+// FHELinalg sum //////////////////////////////////////////////////////////////
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(End2EndJit_FHELinalg, sum_empty) {
+
+  using llvm::ArrayRef;
+  using llvm::Expected;
+
+  using mlir::concretelang::IntLambdaArgument;
+  using mlir::concretelang::JitCompilerEngine;
+  using mlir::concretelang::TensorLambdaArgument;
+
+  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%x) : (tensor<0x!FHE.eint<7>>) -> (!FHE.eint<7>)
+  return %0 : !FHE.eint<7>
+}
+
+)XXX");
+
+  // The sum of no elements is zero.
+  const uint8_t expected = 0;
+
+  ArrayRef<uint8_t> xRef(nullptr, (size_t)0);
+  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {0});
+
+  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
+  ASSERT_EXPECTED_SUCCESS(result);
+
+  ASSERT_EQ(*result, expected);
+}
+
+TEST(End2EndJit_FHELinalg, sum_1D) {
+
+  using llvm::ArrayRef;
+  using llvm::Expected;
+
+  using mlir::concretelang::IntLambdaArgument;
+  using mlir::concretelang::JitCompilerEngine;
+  using mlir::concretelang::TensorLambdaArgument;
+
+  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%x) : (tensor<4x!FHE.eint<7>>) -> (!FHE.eint<7>)
+  return %0 : !FHE.eint<7>
+}
+
+)XXX");
+
+  const uint8_t x[4]{0, 1, 2, 3};
+  const uint8_t expected = 6;
+
+  ArrayRef<uint8_t> xRef((const uint8_t *)x, 4);
+  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {4});
+
+  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
+  ASSERT_EXPECTED_SUCCESS(result);
+
+  ASSERT_EQ(*result, expected);
+}
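A note on the JIT harness (an aside; the template arguments here are reconstructed from this patch's usage, not a definitive API reference): `TensorLambdaArgument` receives a flat buffer plus an explicit shape, so the multi-dimensional tests below rely on the row-major layout of the C arrays. A minimal sketch:

```cpp
// Sketch of the calling convention used by the tests below. Row-major
// flattening of the C array is assumed; uint64_t mirrors the scalar
// result type used throughout this file.
const uint8_t x[3][4] = {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 0, 1}};
llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); // flat view
mlir::concretelang::TensorLambdaArgument<
    mlir::concretelang::IntLambdaArgument<uint8_t>>
    xArg(xRef, {3, 4}); // the shape argument restores the 2-D structure
llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
```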
+
+TEST(End2EndJit_FHELinalg, sum_2D) {
+
+  using llvm::ArrayRef;
+  using llvm::Expected;
+
+  using mlir::concretelang::IntLambdaArgument;
+  using mlir::concretelang::JitCompilerEngine;
+  using mlir::concretelang::TensorLambdaArgument;
+
+  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%x) : (tensor<3x4x!FHE.eint<7>>) -> (!FHE.eint<7>)
+  return %0 : !FHE.eint<7>
+}
+
+)XXX");
+
+  const uint8_t x[3][4]{
+      {0, 1, 2, 3},
+      {4, 5, 6, 7},
+      {8, 9, 0, 1},
+  };
+  const uint8_t expected = 46;
+
+  ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4);
+  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4});
+
+  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
+  ASSERT_EXPECTED_SUCCESS(result);
+
+  ASSERT_EQ(*result, expected);
+}
+
+TEST(End2EndJit_FHELinalg, sum_3D) {
+
+  using llvm::ArrayRef;
+  using llvm::Expected;
+
+  using mlir::concretelang::IntLambdaArgument;
+  using mlir::concretelang::JitCompilerEngine;
+  using mlir::concretelang::TensorLambdaArgument;
+
+  JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%x) : (tensor<3x4x2x!FHE.eint<7>>) -> (!FHE.eint<7>)
+  return %0 : !FHE.eint<7>
+}
+
+)XXX");
+
+  const uint8_t x[3][4][2]{
+      {
+          {0, 1},
+          {2, 3},
+          {4, 5},
+          {6, 7},
+      },
+      {
+          {8, 9},
+          {0, 1},
+          {2, 3},
+          {4, 5},
+      },
+      {
+          {6, 7},
+          {8, 9},
+          {0, 1},
+          {2, 3},
+      },
+  };
+  const uint8_t expected = 96;
+
+  ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
+  TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4, 2});
+
+  Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg});
+  ASSERT_EXPECTED_SUCCESS(result);
+
+  ASSERT_EQ(*result, expected);
+}
+
 class TiledMatMulParametric : public ::testing::TestWithParam> {};