diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
index 3858813eb..0ffa52b7a 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h
@@ -8,6 +8,7 @@
 #include "concretelang/Dialect/FHE/IR/FHETypes.h"
 #include "concretelang/Dialect/FHELinalg/IR/FHELinalgTypes.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include
diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
index 73e6b3488..95858873f 100644
--- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
+++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td
@@ -517,12 +517,21 @@ def ZeroOp : FHELinalg_Op<"zero", []> {
 }

 def SumOp : FHELinalg_Op<"sum", [TensorUnaryEint]> {
-  let summary = "Returns the sum of all elements of a tensor of encrypted integers.";
+  let summary = "Returns the sum of elements of a tensor of encrypted integers along specified axes.";

   let description = [{
-    Performs a sum to a tensor of encrypted integers.
+    Attributes:
+
+    - keep_dims: boolean = false
+      whether to keep the rank of the tensor after the sum operation;
+      if true, reduced axes will have a size of 1
+
+    - axes: I64ArrayAttr = []
+      list of dimensions to perform the sum along;
+      think of them as the dimensions to reduce (see the examples below for intuition)

     Examples:
+
     ```mlir
     // Returns the sum of all elements of `%a0`
     "FHELinalg.sum"(%a0) : (tensor<3x3x!FHE.eint<4>>) -> !FHE.eint<4>
     //
@@ -532,13 +541,64 @@
     //     ( [7,8,9] )
     //
     ```
+
+    ```mlir
+    // Returns the sums of the elements of `%a0` along its columns
+    "FHELinalg.sum"(%a0) { axes = [0] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<2x!FHE.eint<4>>
+    //
+    //     ( [1,2] )
+    // sum ( [3,4] ) = [9, 12]
+    //     ( [5,6] )
+    //
+    ```
+
+    ```mlir
+    // Returns the sums of the elements of `%a0` along its columns while preserving dimensions
+    "FHELinalg.sum"(%a0) { axes = [0], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<1x2x!FHE.eint<4>>
+    //
+    //     ( [1,2] )
+    // sum ( [3,4] ) = [[9, 12]]
+    //     ( [5,6] )
+    //
+    ```
+
+    ```mlir
+    // Returns the sums of the elements of `%a0` along its rows
+    "FHELinalg.sum"(%a0) { axes = [1] } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x!FHE.eint<4>>
+    //
+    //     ( [1,2] )
+    // sum ( [3,4] ) = [3, 7, 11]
+    //     ( [5,6] )
+    //
+    ```
+
+    ```mlir
+    // Returns the sums of the elements of `%a0` along its rows while preserving dimensions
+    "FHELinalg.sum"(%a0) { axes = [1], keep_dims = true } : (tensor<3x2x!FHE.eint<4>>) -> tensor<3x1x!FHE.eint<4>>
+    //
+    //     ( [1,2] )   [3]
+    // sum ( [3,4] ) = [7]
+    //     ( [5,6] )   [11]
+    //
+    ```
   }];

   let arguments = (ins
-    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor
+    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$tensor,
+    DefaultValuedAttr<I64ArrayAttr, "{}">:$axes,
+    DefaultValuedAttr<BoolAttr, "false">:$keep_dims
   );

-  let results = (outs EncryptedIntegerType:$out);
+  let results = (outs
+    TypeConstraint<Or<[EncryptedIntegerType.predicate, And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>
+    ]>>:$out
+  );
+
+  let verifier = [{
+    return mlir::concretelang::FHELinalg::verifySum(*this);
+  }];
 }

 #endif
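The output-shape rule these examples follow is the NumPy one for `np.sum(x, axis=axes, keepdims=keep_dims)`: kept axes carry over, reduced axes either disappear or become 1, and an empty `axes` list means "reduce everything". A minimal standalone sketch of that rule (plain C++; `sumOutputShape` is a hypothetical helper, not part of the patch):

```cpp
#include <cstdint>
#include <set>
#include <vector>

// Computes the expected output shape of FHELinalg.sum, assuming
// NumPy-style semantics (an empty `axes` set reduces every dimension).
std::vector<int64_t> sumOutputShape(const std::vector<int64_t> &inputShape,
                                    std::set<int64_t> axes, bool keepDims) {
  if (axes.empty())
    for (int64_t i = 0; i < (int64_t)inputShape.size(); i++)
      axes.insert(i);

  std::vector<int64_t> outputShape;
  for (int64_t i = 0; i < (int64_t)inputShape.size(); i++) {
    if (axes.count(i) == 0)
      outputShape.push_back(inputShape[i]); // untouched axis is kept as-is
    else if (keepDims)
      outputShape.push_back(1); // reduced axis is kept as a unit dimension
    // otherwise the reduced axis disappears entirely
  }
  return outputShape; // an empty shape means the result is a scalar
}

// e.g. sumOutputShape({3, 2}, {0}, false) == {2}    -> tensor<2x!FHE.eint<4>>
//      sumOutputShape({3, 2}, {0}, true)  == {1, 2} -> tensor<1x2x!FHE.eint<4>>
//      sumOutputShape({3, 2}, {},  false) == {}     -> !FHE.eint<4>
```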
diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index d1342351d..0995b9a0c 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -3,6 +3,8 @@
 // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
 // for license information.

+#include <unordered_set>
+
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
@@ -22,6 +24,13 @@
 #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
 #include "concretelang/Support/Constants.h"

+namespace arith = mlir::arith;
+namespace linalg = mlir::linalg;
+namespace tensor = mlir::tensor;
+
+namespace FHE = mlir::concretelang::FHE;
+namespace FHELinalg = mlir::concretelang::FHELinalg;
+
 struct DotToLinalgGeneric
     : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Dot> {
   DotToLinalgGeneric(::mlir::MLIRContext *context)
@@ -962,8 +971,7 @@ struct FHELinalgZeroToLinalgGenerate
 };

 // This rewrite pattern transforms any instance of operators
-// `FHELinalg.zero` to an instance of `linalg.generate` with an
-// appropriate region yielding a zero value.
+// `FHELinalg.sum` to an instance of `linalg.generic`.
 //
 // Example:
 //
@@ -975,14 +983,14 @@
 // #map0 = affine_map<(i0, i1, ..., iN) -> (i0, i1, ..., iN)>
 // #map1 = affine_map<(i0, i1, ..., iN) -> (0)>
 //
-// %zero = "FHE.zero"() : () -> !FHE.eint<7>
-// %accumulator = tensor.from_elements %zero : tensor<1x!FHE.eint<7>>
-//
+// %accumulator = "FHELinalg.zero"() : () -> tensor<1x!FHE.eint<7>>
 // %accumulation = linalg.generic
-//   { indexing_maps = [#map0, #map1], iterator_types = ["reduction",
-//   "reduction", ..., "reduction"] } ins(%input :
-//   tensor<...x!FHE.eint<7>>) outs(%accumulator :
-//   tensor<1x!FHE.eint<7>>)
+//   {
+//     indexing_maps = [#map0, #map1],
+//     iterator_types = ["reduction", "reduction", ..., "reduction"]
+//   }
+//   ins(%input : tensor<...x!FHE.eint<7>>)
+//   outs(%accumulator : tensor<1x!FHE.eint<7>>)
 //   {
 //     ^bb0(%a: !FHE.eint<7>, %b: !FHE.eint<7>):
 //       %c = "FHE.add_eint"(%a, %b) :
@@ -1003,36 +1011,55 @@ struct SumToLinalgGeneric
   matchAndRewrite(::mlir::concretelang::FHELinalg::SumOp sumOp,
                   ::mlir::PatternRewriter &rewriter) const override {

-    namespace arith = mlir::arith;
-    namespace linalg = mlir::linalg;
-    namespace tensor = mlir::tensor;
-
-    namespace FHE = mlir::concretelang::FHE;
-
     mlir::Location location = sumOp.getLoc();

     mlir::Value input = sumOp.getOperand();
     mlir::Value output = sumOp.getResult();

-    auto inputType = input.getType().dyn_cast_or_null<mlir::RankedTensorType>();
-    assert(inputType != nullptr);
+    auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
+    mlir::Type outputType = output.getType();

     llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
-    size_t inputDimensions = inputShape.size();
+    int64_t inputDimensions = inputShape.size();

-    mlir::Value zero =
-        rewriter.create<FHE::ZeroEintOp>(location, output.getType())
-            .getResult();
+    bool outputIsTensor = outputType.isa<mlir::RankedTensorType>();

-    for (size_t i = 0; i < inputDimensions; i++) {
-      if (inputShape[i] == 0) {
-        rewriter.replaceOp(sumOp, {zero});
+    for (int64_t size : inputShape) {
+      if (size == 0) {
+        mlir::Value result;
+        if (outputIsTensor) {
+          result = rewriter.create<FHELinalg::ZeroOp>(location, outputType)
+                       .getResult();
+        } else {
+          result = rewriter.create<FHE::ZeroEintOp>(location, outputType)
+                       .getResult();
+        }
+        rewriter.replaceOp(sumOp, {result});
         return mlir::success();
       }
     }

+    auto axesToDestroy = std::unordered_set<int64_t>{};
+    for (mlir::Attribute axisAttribute : sumOp.axes()) {
+      int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
+      axesToDestroy.insert(axis);
+    }
+    if (axesToDestroy.empty()) {
+      for (int64_t i = 0; i < inputDimensions; i++) {
+        axesToDestroy.insert(i);
+      }
+    }
+
+    mlir::Type accumulatorType = outputType;
+    if (!outputIsTensor) {
+      int64_t accumulatorShape[1] = {1};
+      accumulatorType = // tensor of shape (1,)
+          mlir::RankedTensorType::get(accumulatorShape, outputType);
+    }
+
     mlir::Value accumulator =
-        rewriter.create<tensor::FromElementsOp>(location, zero).getResult();
+        rewriter.create<FHELinalg::ZeroOp>(location, accumulatorType)
+            .getResult();

     auto ins = llvm::SmallVector<mlir::Value>{input};
     auto outs = llvm::SmallVector<mlir::Value>{accumulator};
@@ -1040,15 +1067,30 @@
     mlir::AffineMap inputMap = mlir::AffineMap::getMultiDimIdentityMap(
         inputDimensions, this->getContext());

+    auto outputAffineExpressions = llvm::SmallVector<mlir::AffineExpr>{};
+    if (outputIsTensor) {
+      for (int64_t i = 0; i < inputDimensions; i++) {
+        bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end();
+        if (!ithAxisIsDestroyed) {
+          outputAffineExpressions.push_back(rewriter.getAffineDimExpr(i));
+        } else if (sumOp.keep_dims()) {
+          outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
+        }
+      }
+    } else {
+      outputAffineExpressions.push_back(rewriter.getAffineConstantExpr(0));
+    }
+
     mlir::AffineMap outputMap = mlir::AffineMap::get(
-        inputDimensions, 0, {rewriter.getAffineConstantExpr(0)},
-        rewriter.getContext());
+        inputDimensions, 0, outputAffineExpressions, rewriter.getContext());

     auto maps = llvm::SmallVector<mlir::AffineMap>{inputMap, outputMap};

-    auto iteratorTypes = llvm::SmallVector<llvm::StringRef>{};
-    for (size_t i = 0; i < inputDimensions; i++) {
-      iteratorTypes.push_back("reduction");
+    auto iteratorTypes = llvm::SmallVector<llvm::StringRef>(
+        inputDimensions, mlir::getParallelIteratorTypeName());
+
+    for (int64_t axis : axesToDestroy) {
+      iteratorTypes[axis] = mlir::getReductionIteratorTypeName();
     }

     auto regionBuilder = [&](mlir::OpBuilder &nestedBuilder,
@@ -1056,29 +1098,30 @@
                              mlir::ValueRange blockArgs) {
       mlir::Value lhs = blockArgs[0];
       mlir::Value rhs = blockArgs[1];
-
       mlir::Value addition =
           nestedBuilder.create<FHE::AddEintOp>(location, lhs, rhs).getResult();
       nestedBuilder.create<linalg::YieldOp>(location, addition);
     };

-    auto resultTypes = llvm::SmallVector<mlir::Type>{accumulator.getType()};
+    auto resultTypes = llvm::SmallVector<mlir::Type>{accumulatorType};

     mlir::Value accumulation =
         rewriter
             .create<linalg::GenericOp>(location, resultTypes, ins, outs, maps,
                                        iteratorTypes, regionBuilder)
             .getResult(0);

-    mlir::Value index =
-        rewriter.create<arith::ConstantIndexOp>(location, 0).getResult();
-    auto indices = llvm::SmallVector<mlir::Value>{index};
+    mlir::Value result = accumulation;
+    if (!outputIsTensor) {
+      auto indices = llvm::SmallVector<mlir::Value>{
+          rewriter.create<arith::ConstantIndexOp>(location, 0).getResult(),
+      };
+      result =
+          rewriter.create<tensor::ExtractOp>(location, accumulation, indices)
+              .getResult();
+    }

-    mlir::Value result =
-        rewriter.create<tensor::ExtractOp>(location, accumulation, indices)
-            .getResult();
     rewriter.replaceOp(sumOp, {result});
-
     return mlir::success();
   };
 };
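The heart of this lowering is the pairing of indexing maps with iterator types: reduced axes become `reduction` iterators and are either dropped from the output map or pinned to the constant 0 when `keep_dims` is set, while every other axis stays `parallel`. A standalone sketch of the selection rule (hypothetical helper, not part of the patch, using the same empty-axes-means-all convention as the op):

```cpp
#include <cstdint>
#include <set>
#include <string>
#include <vector>

// Picks a linalg.generic iterator type per input dimension: summed axes
// accumulate ("reduction"), untouched axes are independent ("parallel").
std::vector<std::string> sumIteratorTypes(int64_t rank,
                                          std::set<int64_t> axes) {
  if (axes.empty()) // no axes given: every dimension is reduced
    for (int64_t i = 0; i < rank; i++)
      axes.insert(i);

  std::vector<std::string> iterators(rank, "parallel");
  for (int64_t axis : axes)
    iterators[axis] = "reduction";
  return iterators;
}

// e.g. sumIteratorTypes(3, {1}) == {"parallel", "reduction", "parallel"},
// which is what the tests below check next to the affine maps
// #map0 = (d0, d1, d2) -> (d0, d1, d2) and #map1 = (d0, d1, d2) -> (d0, d2).
```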
diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index 043d9b091..8331cd716 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -841,23 +841,43 @@ static llvm::APInt getSqMANP(
     mlir::concretelang::FHELinalg::SumOp op,
     llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {

-  auto type = op->getOperand(0).getType().dyn_cast_or_null<mlir::RankedTensorType>();
+  auto inputType = op.getOperand().getType().dyn_cast<mlir::RankedTensorType>();

-  uint64_t numberOfElements = type.getNumElements();
-  if (numberOfElements == 0) {
+  uint64_t numberOfElementsInTheInput = inputType.getNumElements();
+  if (numberOfElementsInTheInput == 0) {
     return llvm::APInt{1, 1, false};
   }

+  uint64_t numberOfElementsAddedTogetherInEachOutputCell = 1;
+
+  mlir::ArrayAttr axes = op.axes();
+  if (axes.empty()) {
+    numberOfElementsAddedTogetherInEachOutputCell *= numberOfElementsInTheInput;
+  } else {
+    llvm::ArrayRef<int64_t> shape = inputType.getShape();
+    for (mlir::Attribute axisAttribute : op.axes()) {
+      int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
+      numberOfElementsAddedTogetherInEachOutputCell *= shape[axis];
+    }
+  }
+
+  unsigned int noiseMultiplierBits =
+      ceilLog2(numberOfElementsAddedTogetherInEachOutputCell + 1);
+
+  auto noiseMultiplier = llvm::APInt{
+      noiseMultiplierBits,
+      numberOfElementsAddedTogetherInEachOutputCell,
+      false,
+  };
+
   assert(operandMANPs.size() == 1 &&
          operandMANPs[0]->getValue().getMANP().hasValue() &&
          "Missing squared Minimal Arithmetic Noise Padding for encrypted "
          "operands");
+
   llvm::APInt operandMANP = operandMANPs[0]->getValue().getMANP().getValue();

-  unsigned int multiplierBits = ceilLog2(numberOfElements + 1);
-  auto multiplier = llvm::APInt{multiplierBits, numberOfElements, false};
-
-  return APIntWidthExtendUMul(multiplier, operandMANP);
+  return APIntWidthExtendUMul(noiseMultiplier, operandMANP);
 }

 struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
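The multiplier is the number of ciphertexts added into each output cell, i.e. the product of the reduced dimension sizes, and the MANP attribute printed in the tests is the ceiling square root of the squared value the analysis propagates. A quick standalone check of the figures asserted in the MANP tests further below (hypothetical helper, assuming operands with squared MANP 1, as produced by `FHELinalg.zero`):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Printed MANP of a sum whose output cells each add `n` ciphertexts,
// assuming each operand carries a squared MANP of 1.
static uint64_t printedMANP(uint64_t n) {
  return static_cast<uint64_t>(std::ceil(std::sqrt(static_cast<double>(n))));
}

int main() {
  // tensor<5x3x4x2x!FHE.eint<7>>, full reduction: 120 elements per cell
  assert(printedMANP(5 * 3 * 4 * 2) == 11);
  // axes = [0, 2]: 5 * 4 = 20 elements per cell
  assert(printedMANP(5 * 4) == 5);
  // axes = [1, 3]: 3 * 2 = 6 elements per cell
  assert(printedMANP(3 * 2) == 3);
  return 0;
}
```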
diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
index d9a9eb521..574195c1a 100644
--- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
+++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp
@@ -3,6 +3,8 @@
 // https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
 // for license information.

+#include <unordered_set>
+
 #include "mlir/IR/TypeUtilities.h"

 #include "concretelang/Dialect/FHE/IR/FHEOps.h"
@@ -393,6 +395,101 @@ verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) {
   return ::mlir::success();
 }

+llvm::SmallVector<int64_t>
+verifySumCalculateActualOutputShape(mlir::Type outputType) {
+  auto actualOutputShape = llvm::SmallVector<int64_t>{};
+  if (outputType.isa<mlir::RankedTensorType>()) {
+    auto outputTensorType = outputType.dyn_cast<mlir::RankedTensorType>();
+    for (int64_t size : outputTensorType.getShape()) {
+      actualOutputShape.push_back(size);
+    }
+  }
+  return actualOutputShape;
+}
+
+llvm::SmallVector<int64_t> verifySumCalculateExpectedOutputShape(
+    llvm::ArrayRef<int64_t> inputShape, int64_t inputDimensions,
+    std::unordered_set<int64_t> &axesToDestroy, bool keepDims) {
+
+  auto expectedOutputShape = llvm::SmallVector<int64_t>{};
+  for (int64_t i = 0; i < inputDimensions; i++) {
+    bool ithAxisIsDestroyed = axesToDestroy.find(i) != axesToDestroy.end();
+    if (!ithAxisIsDestroyed) {
+      expectedOutputShape.push_back(inputShape[i]);
+    } else if (keepDims) {
+      expectedOutputShape.push_back(1);
+    }
+  }
+  return expectedOutputShape;
+}
+
+mlir::LogicalResult verifySum(SumOp &op) {
+  mlir::Value input = op.getOperand();
+  mlir::Value output = op.getResult();
+
+  auto inputType = input.getType().dyn_cast<mlir::RankedTensorType>();
+  mlir::Type outputType = output.getType();
+
+  FHE::EncryptedIntegerType inputElementType =
+      inputType.getElementType().dyn_cast<FHE::EncryptedIntegerType>();
+  FHE::EncryptedIntegerType outputElementType =
+      !outputType.isa<mlir::RankedTensorType>()
+          ? outputType.dyn_cast<FHE::EncryptedIntegerType>()
+          : outputType.dyn_cast<mlir::RankedTensorType>()
+                .getElementType()
+                .dyn_cast<FHE::EncryptedIntegerType>();
+
+  if (!FHE::verifyEncryptedIntegerInputAndResultConsistency(
+          op, inputElementType, outputElementType)) {
+    return mlir::failure();
+  }
+
+  llvm::ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputDimensions = (int64_t)inputShape.size();
+
+  mlir::ArrayAttr axes = op.axes();
+  bool keepDims = op.keep_dims();
+
+  auto axesToDestroy = std::unordered_set<int64_t>{};
+  for (mlir::Attribute axisAttribute : axes) {
+    int64_t axis = axisAttribute.cast<mlir::IntegerAttr>().getInt();
+
+    bool axisIsValid = (0 <= axis) && (axis < inputDimensions);
+    if (!axisIsValid) {
+      op.emitOpError("has invalid axes attribute");
+      return mlir::failure();
+    }
+
+    axesToDestroy.insert(axis);
+  }
+  if (axesToDestroy.empty()) {
+    for (int64_t i = 0; i < inputDimensions; i++) {
+      axesToDestroy.insert(i);
+    }
+  }
+
+  auto expectedOutputShape = verifySumCalculateExpectedOutputShape(
+      inputShape, inputDimensions, axesToDestroy, keepDims);
+  auto actualOutputShape = verifySumCalculateActualOutputShape(outputType);
+
+  if (expectedOutputShape != actualOutputShape) {
+    auto stream = op.emitOpError();
+
+    stream << "does not have the proper output shape of <";
+    if (!expectedOutputShape.empty()) {
+      stream << expectedOutputShape[0];
+      for (size_t i = 1; i < expectedOutputShape.size(); i++) {
+        stream << "x" << expectedOutputShape[i];
+      }
+    }
+    stream << ">";
+
+    return mlir::failure();
+  }
+
+  return mlir::success();
+}
+
 /// Verify the matmul shapes, the type of tensor elements should be checked by
 /// something else
 template <typename MatMulOp> mlir::LogicalResult verifyMatmul(MatMulOp &op) {
diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir
new file mode 100644
index 000000000..dcc89dfb7
--- /dev/null
+++ b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum.mlir
@@ -0,0 +1,555 @@
+// RUN: concretecompiler --split-input-file --action=dump-tfhe --passes fhe-tensor-ops-to-linalg %s 2>&1 | FileCheck %s
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @main(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: return %[[v0]] : !FHE.eint<7>
+// CHECK-NEXT: }
+func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> {
+  %0 = "FHELinalg.sum"(%arg0) : (tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7>
+  return %0 : !FHE.eint<7>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
+// CHECK-NEXT: %[[v0:.*]] = tensor.generate {
+// CHECK-NEXT: ^bb0(%[[d0:.*]]: index):
+// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7>
+// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7>
+// CHECK-NEXT: } : tensor<3x4x!FHE.eint<7>>
+// CHECK-NEXT: return %[[v0]] : tensor<3x4x!FHE.eint<7>>
+// CHECK-NEXT: }
+func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> {
+  %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>>
+  return %0 : tensor<3x4x!FHE.eint<7>>
+}
+
+// -----
+
+// CHECK: func @main(%[[a0:.*]]:
tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x1x4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x1x4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> + return %0 : tensor<3x1x4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x0x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x0x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [2] } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> + return %0 : tensor<3x0x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x0x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x0x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [2], keep_dims = true } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> + return %0 : tensor<3x0x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): 
+// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0] } : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0], keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> 
(d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : 
tensor<4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, d1)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x4x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> + return %0 : tensor<1x4x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<3x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<3x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + return %0 : tensor<3x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: 
!FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> + return %0 : tensor<3x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1] } : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// 
CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, 0, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + return %0 : tensor<1x1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x2x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x4x2x!FHE.eint<7>>) -> 
tensor<3x2x!FHE.eint<7>> + return %0 : tensor<3x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "reduction", "parallel"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<3x1x2x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> + return %0 : tensor<3x1x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<4x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, d1, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "parallel", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x4x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], 
%[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> + return %0 : tensor<1x4x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[v2:.*]] = tensor.extract %[[v1]][%[[c0]]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v2]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0, 0, 0)> + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = tensor.generate { +// CHECK-NEXT: ^bb0(%[[d0:.*]]: index): +// CHECK-NEXT: %[[yld:.*]] = "FHE.zero"() : () -> !FHE.eint<7> +// CHECK-NEXT: tensor.yield %[[yld]] : !FHE.eint<7> +// CHECK-NEXT: } : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: %[[v1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v0]] : tensor<1x1x1x!FHE.eint<7>>) { +// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): +// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> +// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> +// CHECK-NEXT: } -> tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v1]] : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + return %0 : tensor<1x1x1x!FHE.eint<7>> +} diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir deleted file mode 100644 index 2f3d6fc7b..000000000 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_1d.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: 
concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s - -// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> - -// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction"]} ins(%[[a0]] : tensor<4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> -// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: return %[[v3]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir deleted file mode 100644 index 9c8af2f59..000000000 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_2d.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s - -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (0)> - -// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> -// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: return %[[v3]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir deleted file mode 100644 index 94ad08c88..000000000 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_3d.mlir +++ /dev/null @@ -1,21 +0,0 @@ -// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s - -// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (0)> - -// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> 
!FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: %[[v1:.*]] = tensor.from_elements %[[v0]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[v2:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["reduction", "reduction", "reduction"]} ins(%[[a0]] : tensor<3x4x2x!FHE.eint<7>>) outs(%[[v1]] : tensor<1x!FHE.eint<7>>) { -// CHECK-NEXT: ^bb0(%[[aa0:.*]]: !FHE.eint<7>, %[[aa1:.*]]: !FHE.eint<7>): -// CHECK-NEXT: %[[vv0:.*]] = "FHE.add_eint"(%[[aa0]], %[[aa1]]) : (!FHE.eint<7>, !FHE.eint<7>) -> !FHE.eint<7> -// CHECK-NEXT: linalg.yield %[[vv0]] : !FHE.eint<7> -// CHECK-NEXT: } -> tensor<1x!FHE.eint<7>> -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[v3:.*]] = tensor.extract %[[v2]][%[[c0]]] : tensor<1x!FHE.eint<7>> -// CHECK-NEXT: return %[[v3]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} diff --git a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir b/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir deleted file mode 100644 index 5071cb970..000000000 --- a/compiler/tests/Conversion/FHELinalgToLinalg/FHELinalgToLinalg/sum_empty.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s - -// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHE.zero"() : () -> !FHE.eint<7> -// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir index b9789e4b0..a3bdd86d0 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir @@ -394,21 +394,155 @@ func @zero() -> tensor<8x!FHE.eint<2>> // ----- func @sum() -> !FHE.eint<7> { - %0 = "FHELinalg.zero"() : () -> tensor<4x!FHE.eint<7>> - // CHECK: "FHELinalg.sum"(%0) {MANP = 2 : ui{{[0-9]+}}} : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> - %1 = "FHELinalg.sum"(%0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + %0 = "FHELinalg.zero"() : () -> tensor<5x3x4x2x!FHE.eint<7>> - %2 = "FHELinalg.zero"() : () -> tensor<5x!FHE.eint<7>> - // CHECK: "FHELinalg.sum"(%2) {MANP = 3 : ui{{[0-9]+}}} : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7> - %3 = "FHELinalg.sum"(%2) : (tensor<5x!FHE.eint<7>>) -> !FHE.eint<7> + // CHECK: MANP = 11 : ui{{[0-9]+}} + %1 = "FHELinalg.sum"(%0) : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> - %4 = "FHELinalg.zero"() : () -> tensor<9x!FHE.eint<7>> - // CHECK: "FHELinalg.sum"(%4) {MANP = 3 : ui{{[0-9]+}}} : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7> - %5 = "FHELinalg.sum"(%4) : (tensor<9x!FHE.eint<7>>) -> !FHE.eint<7> + // CHECK: MANP = 3 : ui{{[0-9]+}} + %2 = "FHELinalg.sum"(%0) { axes = [0] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x2x!FHE.eint<7>> - %6 = "FHELinalg.zero"() : () -> tensor<10x!FHE.eint<7>> - // CHECK: "FHELinalg.sum"(%6) {MANP = 4 : ui{{[0-9]+}}} : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7> - %7 = "FHELinalg.sum"(%6) : (tensor<10x!FHE.eint<7>>) -> !FHE.eint<7> + // CHECK: MANP = 2 : ui{{[0-9]+}} + %3 = "FHELinalg.sum"(%0) { 
axes = [1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x2x!FHE.eint<7>> - return %7 : !FHE.eint<7> + // CHECK: MANP = 2 : ui{{[0-9]+}} + %4 = "FHELinalg.sum"(%0) { axes = [2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %5 = "FHELinalg.sum"(%0) { axes = [3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %6 = "FHELinalg.sum"(%0) { axes = [0, 1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x2x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %7 = "FHELinalg.sum"(%0) { axes = [0, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %8 = "FHELinalg.sum"(%0) { axes = [0, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %9 = "FHELinalg.sum"(%0) { axes = [1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x2x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %10 = "FHELinalg.sum"(%0) { axes = [1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x4x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %11 = "FHELinalg.sum"(%0) { axes = [2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x!FHE.eint<7>> + + // CHECK: MANP = 8 : ui{{[0-9]+}} + %12 = "FHELinalg.sum"(%0) { axes = [0, 1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + + // CHECK: MANP = 6 : ui{{[0-9]+}} + %13 = "FHELinalg.sum"(%0) { axes = [0, 1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + + // CHECK: MANP = 7 : ui{{[0-9]+}} + %14 = "FHELinalg.sum"(%0) { axes = [0, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %15 = "FHELinalg.sum"(%0) { axes = [1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x!FHE.eint<7>> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %16 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %17 = "FHELinalg.sum"(%0) { keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %18 = "FHELinalg.sum"(%0) { axes = [0], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %19 = "FHELinalg.sum"(%0) { axes = [1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %20 = "FHELinalg.sum"(%0) { axes = [2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x2x!FHE.eint<7>> + + // CHECK: MANP = 2 : ui{{[0-9]+}} + %21 = "FHELinalg.sum"(%0) { axes = [3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x4x1x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %22 = "FHELinalg.sum"(%0) { axes = [0, 1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x2x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %23 = "FHELinalg.sum"(%0) { axes = [0, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x2x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %24 = "FHELinalg.sum"(%0) { axes = [0, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x4x1x!FHE.eint<7>> + + // CHECK: MANP = 4 : ui{{[0-9]+}} + %25 = "FHELinalg.sum"(%0) { axes = [1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x2x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %26 = "FHELinalg.sum"(%0) { axes = [1, 3], keep_dims = true 
} : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x4x1x!FHE.eint<7>> + + // CHECK: MANP = 3 : ui{{[0-9]+}} + %27 = "FHELinalg.sum"(%0) { axes = [2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x3x1x1x!FHE.eint<7>> + + // CHECK: MANP = 8 : ui{{[0-9]+}} + %28 = "FHELinalg.sum"(%0) { axes = [0, 1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x2x!FHE.eint<7>> + + // CHECK: MANP = 6 : ui{{[0-9]+}} + %29 = "FHELinalg.sum"(%0) { axes = [0, 1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x4x1x!FHE.eint<7>> + + // CHECK: MANP = 7 : ui{{[0-9]+}} + %30 = "FHELinalg.sum"(%0) { axes = [0, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x3x1x1x!FHE.eint<7>> + + // CHECK: MANP = 5 : ui{{[0-9]+}} + %31 = "FHELinalg.sum"(%0) { axes = [1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<5x1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 11 : ui{{[0-9]+}} + %32 = "FHELinalg.sum"(%0) { axes = [0, 1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x1x!FHE.eint<7>> + + // =============================== + + %35 = "FHELinalg.zero"() : () -> tensor<2x0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %36 = "FHELinalg.sum"(%35) : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %37 = "FHELinalg.sum"(%35) { axes = [0] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %38 = "FHELinalg.sum"(%35) { axes = [1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %39 = "FHELinalg.sum"(%35) { axes = [2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %40 = "FHELinalg.sum"(%35) { axes = [0, 1] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %41 = "FHELinalg.sum"(%35) { axes = [0, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<0x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %42 = "FHELinalg.sum"(%35) { axes = [1, 2] } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %43 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2] } : (tensor<2x0x3x!FHE.eint<7>>) -> !FHE.eint<7> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %44 = "FHELinalg.sum"(%35) { keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %45 = "FHELinalg.sum"(%35) { axes = [0], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %46 = "FHELinalg.sum"(%35) { axes = [1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %47 = "FHELinalg.sum"(%35) { axes = [2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x0x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %48 = "FHELinalg.sum"(%35) { axes = [0, 1], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x1x3x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %49 = "FHELinalg.sum"(%35) { axes = [0, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<1x0x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %50 = "FHELinalg.sum"(%35) { axes = [1, 2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> tensor<2x1x1x!FHE.eint<7>> + + // CHECK: MANP = 1 : ui{{[0-9]+}} + %51 = "FHELinalg.sum"(%35) { axes = [0, 1 ,2], keep_dims = true } : (tensor<2x0x3x!FHE.eint<7>>) -> 
tensor<1x1x1x!FHE.eint<7>> + + return %1 : !FHE.eint<7> } diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir index 740f38224..f718d8f1f 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir @@ -367,43 +367,3 @@ func @zero_2D() -> tensor<4x9x!FHE.eint<2>> { %0 = "FHELinalg.zero"() : () -> tensor<4x9x!FHE.eint<2>> return %0 : tensor<4x9x!FHE.eint<2>> } - -///////////////////////////////////////////////// -// FHELinalg.sum -///////////////////////////////////////////////// - -// CHECK: func @sum_empty(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> -// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_empty(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} - -// CHECK: func @sum_1D(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> -// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_1D(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} - -// CHECK: func @sum_2D(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> -// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_2D(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} - -// CHECK: func @sum_3D(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { -// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> -// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> -// CHECK-NEXT: } -func @sum_3D(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> - return %0 : !FHE.eint<7> -} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.invalid.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.invalid.mlir new file mode 100644 index 000000000..f63d8f5f8 --- /dev/null +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.invalid.mlir @@ -0,0 +1,281 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip --verify-diagnostics %s + +// ----- + +func @sum_invalid_bitwidth(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<6> { + // expected-error @+1 {{'FHELinalg.sum' op should have the width of encrypted inputs and result equals}} + %1 = "FHELinalg.sum"(%arg0): (tensor<4x!FHE.eint<7>>) -> !FHE.eint<6> + return %1 : !FHE.eint<6> +} + +// ----- + +func @sum_invalid_axes_1(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + // expected-error @+1 {{'FHELinalg.sum' op has invalid axes attribute}} + %1 = "FHELinalg.sum"(%arg0) { axes = [4] } : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %1 : !FHE.eint<7> +} + +// ----- + +func @sum_invalid_axes_2(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + // expected-error @+1 {{'FHELinalg.sum' op has invalid axes attribute}} + %1 = "FHELinalg.sum"(%arg0) { axes = [-1] } : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %1 : !FHE.eint<7> +} + +// ----- + +func @sum_invalid_shape_01(%arg0: 
tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <>}} + %1 = "FHELinalg.sum"(%arg0) : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_02(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <3x4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_03(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_04(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_05(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3x4>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_06(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_07(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <3x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_08(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <3x4>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_09(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_10(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x4>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func 
@sum_invalid_shape_11(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_12(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_13(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <4>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_14(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <3>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_15(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_16(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2, 3] } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_17(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x1x1x1>}} + %1 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_18(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x3x4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_19(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x1x4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_20(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3x1x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [2], keep_dims = true } : 
(tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_21(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3x4x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_22(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x1x4x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_23(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x3x1x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_24(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x3x4x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_25(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x1x1x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_26(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x1x4x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_27(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x3x1x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_28(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x1x1x2>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_29(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x1x4x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// 
----- + +func @sum_invalid_shape_30(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x3x1x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_31(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <5x1x1x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} + +// ----- + +func @sum_invalid_shape_32(%arg0: tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> { + // expected-error @+1 {{'FHELinalg.sum' op does not have the proper output shape of <1x1x1x1>}} + %1 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2, 3], keep_dims = true } : (tensor<5x3x4x2x!FHE.eint<7>>) -> tensor<10x20x!FHE.eint<7>> + return %1 : tensor<10x20x!FHE.eint<7>> +} diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.mlir new file mode 100644 index 000000000..4f3592858 --- /dev/null +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/sum.mlir @@ -0,0 +1,287 @@ +// RUN: concretecompiler --split-input-file --action=roundtrip %s 2>&1| FileCheck %s + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<3x0x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1]} : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> + return %0 : tensor<3x4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1], keep_dims = true} : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x1x4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x1x4x!FHE.eint<7>> + return %0 : tensor<3x1x4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> 
{ +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [2]} : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x0x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [2] } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x!FHE.eint<7>> + return %0 : tensor<3x0x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [2], keep_dims = true} : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x0x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [2], keep_dims = true } : (tensor<3x0x4x!FHE.eint<7>>) -> tensor<3x0x1x!FHE.eint<7>> + return %0 : tensor<3x0x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0]} : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0] } : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {keep_dims = true} : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0], keep_dims = true} : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0], keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {keep_dims = true} : (tensor<3x4x!FHE.eint<7>>) -> 
tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0]} : (tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0], keep_dims = true} : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> + return %0 : tensor<1x4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1]} : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + return %0 : tensor<3x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1], keep_dims = true} : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> + return %0 : tensor<3x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 1]} : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1] } : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 1], keep_dims = true} : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: 
tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {keep_dims = true} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + return %0 : tensor<1x1x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1]} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x2x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + return %0 : tensor<3x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [1], keep_dims = true} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<3x1x2x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [1], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> + return %0 : tensor<3x1x2x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 2]} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<4x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 2], keep_dims = true} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x4x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> + return %0 : tensor<1x4x1x!FHE.eint<7>> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 1, 2]} : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> +// CHECK-NEXT: return %[[v0]] : !FHE.eint<7> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2] } : 
(tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +// ----- + +// CHECK: func @main(%[[a0:.*]]: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { +// CHECK-NEXT: %[[v0:.*]] = "FHELinalg.sum"(%[[a0]]) {axes = [0, 1, 2], keep_dims = true} : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: return %[[v0]] : tensor<1x1x1x!FHE.eint<7>> +// CHECK-NEXT: } +func @main(%arg0: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%arg0) { axes = [0, 1, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + return %0 : tensor<1x1x1x!FHE.eint<7>> +} diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc index 32cdc119c..6c8f17d84 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc @@ -1584,18 -1584,12 @@ func @main() -> tensor<2x2x4x!FHE.eint<6>> { /////////////////////////////////////////////////////////////////////////////// TEST(End2EndJit_FHELinalg, sum_empty) { + namespace concretelang = mlir::concretelang; - using llvm::ArrayRef; - using llvm::Expected; - - using mlir::concretelang::IntLambdaArgument; - using mlir::concretelang::JitCompilerEngine; - using mlir::concretelang::TensorLambdaArgument; - - JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%x: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%x) : (tensor<0x!FHE.eint<7>>) -> (!FHE.eint<7>) + %0 = "FHELinalg.sum"(%x) : (tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> return %0 : !FHE.eint<7> } @@ -1603,28 +1597,23 @@ func @main(%x: tensor<0x!FHE.eint<7>>) -> !FHE.eint<7> { const uint8_t expected = 0; - ArrayRef<uint8_t> xRef(nullptr, (size_t)0); - TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {0}); + llvm::ArrayRef<uint8_t> xRef(nullptr, (size_t)0); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {0}); - Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); ASSERT_EXPECTED_SUCCESS(result); ASSERT_EQ(*result, expected); } -TEST(End2EndJit_FHELinalg, sum_1D) { +TEST(End2EndJit_FHELinalg, sum_1D_no_axes) { + namespace concretelang = mlir::concretelang; - using llvm::ArrayRef; - using llvm::Expected; - - using mlir::concretelang::IntLambdaArgument; - using mlir::concretelang::JitCompilerEngine; - using mlir::concretelang::TensorLambdaArgument; - - JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%x) : (tensor<4x!FHE.eint<7>>) -> (!FHE.eint<7>) + %0 = "FHELinalg.sum"(%x) : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> return %0 : !FHE.eint<7> } @@ -1633,28 +1622,48 @@ func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { const uint8_t x[4]{0, 1, 2, 3}; const uint8_t expected = 6; - ArrayRef<uint8_t> xRef((const uint8_t *)x, 4); - TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {4}); + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {4}); - Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); ASSERT_EXPECTED_SUCCESS(result); ASSERT_EQ(*result, expected); } -TEST(End2EndJit_FHELinalg, sum_2D) { +TEST(End2EndJit_FHELinalg, sum_1D_axes_0) { + namespace concretelang = mlir::concretelang; - using llvm::ArrayRef; - using llvm::Expected; +
concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( - using mlir::concretelang::IntLambdaArgument; - using mlir::concretelang::JitCompilerEngine; - using mlir::concretelang::TensorLambdaArgument; +func @main(%x: tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%x) { axes = [0] } : (tensor<4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} - JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +)XXX"); + + const uint8_t x[4]{0, 1, 2, 3}; + const uint8_t expected = 6; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {4}); + + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + ASSERT_EXPECTED_SUCCESS(result); + + ASSERT_EQ(*result, expected); +} + +TEST(End2EndJit_FHELinalg, sum_2D_no_axes) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%x) : (tensor<3x4x!FHE.eint<7>>) -> (!FHE.eint<7>) + %0 = "FHELinalg.sum"(%x) : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> return %0 : !FHE.eint<7> } @@ -1667,28 +1676,134 @@ func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { }; const uint8_t expected = 46; - ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); - TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4}); + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); - Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); ASSERT_EXPECTED_SUCCESS(result); ASSERT_EQ(*result, expected); } -TEST(End2EndJit_FHELinalg, sum_3D) { +TEST(End2EndJit_FHELinalg, sum_2D_axes_0) { + namespace concretelang = mlir::concretelang; - using llvm::ArrayRef; - using llvm::Expected; + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( - using mlir::concretelang::IntLambdaArgument; - using mlir::concretelang::JitCompilerEngine; - using mlir::concretelang::TensorLambdaArgument; +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} - JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[4]{12, 15, 8, 11}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 4); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 4); + + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_axes_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [1] } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + return %0 : tensor<3x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[3]{6, 22, 18}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> +
xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 3); + + for (size_t i = 0; i < 3; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_axes_0_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 1] } : (tensor<3x4x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected = 46; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + ASSERT_EXPECTED_SUCCESS(result); + + ASSERT_EQ(*result, expected); +} + +TEST(End2EndJit_FHELinalg, sum_3D_no_axes) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { - %0 = "FHELinalg.sum"(%x) : (tensor<3x4x2x!FHE.eint<7>>) -> (!FHE.eint<7>) + %0 = "FHELinalg.sum"(%x) : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> return %0 : !FHE.eint<7> } @@ -1716,15 +1831,1234 @@ func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { }; const uint8_t expected = 96; - ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); - TensorLambdaArgument<IntLambdaArgument<uint8_t>> xArg(xRef, {3, 4, 2}); + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); - Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); + llvm::Expected<uint64_t> result = lambda.operator()<uint64_t>({&xArg}); ASSERT_EXPECTED_SUCCESS(result); ASSERT_EQ(*result, expected); } +TEST(End2EndJit_FHELinalg, sum_3D_axes_0) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x2x!FHE.eint<7>> + return %0 : tensor<4x2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[4][2]{ + {14, 17}, + {10, 13}, + {6, 9}, + {12, 15}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 4); + ASSERT_EQ(res.getDimensions().at(1), 2); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 8); + + for (size_t i = 0; i < 4; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ(res.getValue()[(i * 2) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x:
tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [1] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x2x!FHE.eint<7>> + return %0 : tensor<3x2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[3][2]{ + {12, 16}, + {14, 18}, + {16, 20}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EQ(res.getDimensions().at(1), 2); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 6); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 2; j++) { + EXPECT_EQ(res.getValue()[(i * 2) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_2) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [2] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x4x!FHE.eint<7>> + return %0 : tensor<3x4x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[3][4]{ + {1, 5, 9, 13}, + {17, 1, 5, 9}, + {13, 17, 1, 5}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EQ(res.getDimensions().at(1), 4); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 12); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 4; j++) { + EXPECT_EQ(res.getValue()[(i * 4) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_0_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 1] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<2x!FHE.eint<7>> + return %0 : tensor<2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[2]{42, 54}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 2); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 2); + + for (size_t i = 0; i < 2; i++) { + EXPECT_EQ(res.getValue()[i],
expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_1_2) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [1, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x!FHE.eint<7>> + return %0 : tensor<3x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[3]{28, 32, 36}; + + llvm::ArrayRef xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument> + xArg(xRef, {3, 4, 2}); + + llvm::Expected> call = + lambda.operator()>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument> &res = + (*call) + ->cast>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 3); + + for (size_t i = 0; i < 3; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_0_2) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<4x!FHE.eint<7>> + return %0 : tensor<4x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[4]{31, 23, 15, 27}; + + llvm::ArrayRef xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument> + xArg(xRef, {3, 4, 2}); + + llvm::Expected> call = + lambda.operator()>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument> &res = + (*call) + ->cast>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 4); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 4); + + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_axes_0_1_2) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 1, 2] } : (tensor<3x4x2x!FHE.eint<7>>) -> !FHE.eint<7> + return %0 : !FHE.eint<7> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected = 96; + + llvm::ArrayRef xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument> + xArg(xRef, {3, 4, 2}); + + llvm::Expected result = lambda.operator()({&xArg}); + ASSERT_EXPECTED_SUCCESS(result); + + ASSERT_EQ(*result, expected); +} + +TEST(End2EndJit_FHELinalg, sum_keep_dims_empty) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<0x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { keep_dims = true } : (tensor<0x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return 
%0 : tensor<1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t expected[1] = {0}; + + llvm::ArrayRef<uint8_t> xRef(nullptr, (size_t)0); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {0}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_1D_keep_dims_no_axes) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[4]{0, 1, 2, 3}; + const uint8_t expected[1] = {6}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_1D_keep_dims_axes_0) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0], keep_dims = true } : (tensor<4x!FHE.eint<7>>) -> tensor<1x!FHE.eint<7>> + return %0 : tensor<1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[4]{0, 1, 2, 3}; + const uint8_t expected[1] = {6}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)1); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + EXPECT_EQ(res.getValue()[i], expected[i]) << ", at pos(" << i << ")"; + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_keep_dims_no_axes) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[1][1] = {{46}}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EQ(res.getDimensions().at(1),
1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + for (size_t j = 0; j < 1; j++) { + EXPECT_EQ(res.getValue()[(i * 1) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_keep_dims_axes_0) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x4x!FHE.eint<7>> + return %0 : tensor<1x4x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[1][4]{ + {12, 15, 8, 11}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EQ(res.getDimensions().at(1), 4); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 4); + + for (size_t i = 0; i < 1; i++) { + for (size_t j = 0; j < 4; j++) { + EXPECT_EQ(res.getValue()[(i * 4) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_keep_dims_axes_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<3x1x!FHE.eint<7>> + return %0 : tensor<3x1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[3][1]{ + {6}, + {22}, + {18}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EQ(res.getDimensions().at(1), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 1; j++) { + EXPECT_EQ(res.getValue()[(i * 1) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_2D_keep_dims_axes_0_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 1], keep_dims = true } : (tensor<3x4x!FHE.eint<7>>) -> tensor<1x1x!FHE.eint<7>> + return %0 : tensor<1x1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4]{ + {0, 1, 2, 3}, + {4, 5, 6, 7}, + {8, 9, 0, 1}, + }; + const uint8_t expected[1][1] = {{46}}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)2); + ASSERT_EQ(res.getDimensions().at(0),
1); + ASSERT_EQ(res.getDimensions().at(1), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + for (size_t j = 0; j < 1; j++) { + EXPECT_EQ(res.getValue()[(i * 1) + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_no_axes) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> + return %0 : tensor<1x1x1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[1][1][1] = {{{96}}}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)3); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EQ(res.getDimensions().at(1), 1); + ASSERT_EQ(res.getDimensions().at(2), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 1); + + for (size_t i = 0; i < 1; i++) { + for (size_t j = 0; j < 1; j++) { + for (size_t k = 0; k < 1; k++) { + EXPECT_EQ(res.getValue()[(i * 1 * 1) + (j * 1) + k], expected[i][j][k]) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_0) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x2x!FHE.eint<7>> + return %0 : tensor<1x4x2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[1][4][2]{{ + {14, 17}, + {10, 13}, + {6, 9}, + {12, 15}, + }}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)3); + ASSERT_EQ(res.getDimensions().at(0), 1); + ASSERT_EQ(res.getDimensions().at(1), 4); + ASSERT_EQ(res.getDimensions().at(2), 2); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 8); + + for (size_t i = 0; i < 1; i++) { + for (size_t j = 0; j < 4; j++) { + for (size_t k = 0; k < 2; k++) { + EXPECT_EQ(res.getValue()[(i * 4 * 2) + (j * 2) + k], expected[i][j][k]) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [1], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x2x!FHE.eint<7>> + return %0 :
tensor<3x1x2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[3][1][2]{ + {{12, 16}}, + {{14, 18}}, + {{16, 20}}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)3); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EQ(res.getDimensions().at(1), 1); + ASSERT_EQ(res.getDimensions().at(2), 2); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 6); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 1; j++) { + for (size_t k = 0; k < 2; k++) { + EXPECT_EQ(res.getValue()[(i * 1 * 2) + (j * 2) + k], expected[i][j][k]) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_2) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x4x1x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x4x1x!FHE.eint<7>> + return %0 : tensor<3x4x1x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[3][4][1]{ + {{1}, {5}, {9}, {13}}, + {{17}, {1}, {5}, {9}}, + {{13}, {17}, {1}, {5}}, + }; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(), (size_t)3); + ASSERT_EQ(res.getDimensions().at(0), 3); + ASSERT_EQ(res.getDimensions().at(1), 4); + ASSERT_EQ(res.getDimensions().at(2), 1); + ASSERT_EXPECTED_VALUE(res.getNumElements(), 12); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 4; j++) { + for (size_t k = 0; k < 1; k++) { + EXPECT_EQ(res.getValue()[(i * 4 * 1) + (j * 1) + k], expected[i][j][k]) + << ", at pos(" << i << "," << j << "," << k << ")"; + } + } + } +} + +TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_0_1) { + namespace concretelang = mlir::concretelang; + + concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + +func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x2x!FHE.eint<7>> { + %0 = "FHELinalg.sum"(%x) { axes = [0, 1], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x2x!FHE.eint<7>> + return %0 : tensor<1x1x2x!FHE.eint<7>> +} + +)XXX"); + + const uint8_t x[3][4][2]{ + { + {0, 1}, + {2, 3}, + {4, 5}, + {6, 7}, + }, + { + {8, 9}, + {0, 1}, + {2, 3}, + {4, 5}, + }, + { + {6, 7}, + {8, 9}, + {0, 1}, + {2, 3}, + }, + }; + const uint8_t expected[1][1][2]{{{42, 54}}}; + + llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2); + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>> + xArg(xRef, {3, 4, 2}); + + llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call = + lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>({&xArg}); + ASSERT_EXPECTED_SUCCESS(call); + + concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>> &res = + (*call) + ->cast<concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>>(); + + ASSERT_EQ(res.getDimensions().size(),
+
+TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_0_1) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x2x!FHE.eint<7>> {
+  %0 = "FHELinalg.sum"(%x) { axes = [0, 1], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x2x!FHE.eint<7>>
+  return %0 : tensor<1x1x2x!FHE.eint<7>>
+}
+
+)XXX");
+
+  const uint8_t x[3][4][2]{
+      {
+          {0, 1},
+          {2, 3},
+          {4, 5},
+          {6, 7},
+      },
+      {
+          {8, 9},
+          {0, 1},
+          {2, 3},
+          {4, 5},
+      },
+      {
+          {6, 7},
+          {8, 9},
+          {0, 1},
+          {2, 3},
+      },
+  };
+  const uint8_t expected[1][1][2]{{{42, 54}}};
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3, 4, 2});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 1);
+  ASSERT_EQ(res.getDimensions().at(1), 1);
+  ASSERT_EQ(res.getDimensions().at(2), 2);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 2);
+
+  for (size_t i = 0; i < 1; i++) {
+    for (size_t j = 0; j < 1; j++) {
+      for (size_t k = 0; k < 2; k++) {
+        EXPECT_EQ(res.getValue()[(i * 1 * 2) + (j * 2) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_1_2) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x1x!FHE.eint<7>> {
+  %0 = "FHELinalg.sum"(%x) { axes = [1, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<3x1x1x!FHE.eint<7>>
+  return %0 : tensor<3x1x1x!FHE.eint<7>>
+}
+
+)XXX");
+
+  const uint8_t x[3][4][2]{
+      {
+          {0, 1},
+          {2, 3},
+          {4, 5},
+          {6, 7},
+      },
+      {
+          {8, 9},
+          {0, 1},
+          {2, 3},
+          {4, 5},
+      },
+      {
+          {6, 7},
+          {8, 9},
+          {0, 1},
+          {2, 3},
+      },
+  };
+  const uint8_t expected[3][1][1]{{{28}}, {{32}}, {{36}}};
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3, 4, 2});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 3);
+  ASSERT_EQ(res.getDimensions().at(1), 1);
+  ASSERT_EQ(res.getDimensions().at(2), 1);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 3);
+
+  for (size_t i = 0; i < 3; i++) {
+    for (size_t j = 0; j < 1; j++) {
+      for (size_t k = 0; k < 1; k++) {
+        EXPECT_EQ(res.getValue()[(i * 1 * 1) + (j * 1) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
+TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_0_2) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>> {
+  %0 = "FHELinalg.sum"(%x) { axes = [0, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x4x1x!FHE.eint<7>>
+  return %0 : tensor<1x4x1x!FHE.eint<7>>
+}
+
+)XXX");
+
+  const uint8_t x[3][4][2]{
+      {
+          {0, 1},
+          {2, 3},
+          {4, 5},
+          {6, 7},
+      },
+      {
+          {8, 9},
+          {0, 1},
+          {2, 3},
+          {4, 5},
+      },
+      {
+          {6, 7},
+          {8, 9},
+          {0, 1},
+          {2, 3},
+      },
+  };
+  const uint8_t expected[1][4][1]{{{31}, {23}, {15}, {27}}};
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3, 4, 2});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 1);
+  ASSERT_EQ(res.getDimensions().at(1), 4);
+  ASSERT_EQ(res.getDimensions().at(2), 1);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 4);
+
+  for (size_t i = 0; i < 1; i++) {
+    for (size_t j = 0; j < 4; j++) {
+      for (size_t k = 0; k < 1; k++) {
+        EXPECT_EQ(res.getValue()[(i * 4 * 1) + (j * 1) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
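For the multi-axis cases, a NumPy-style reference model makes the `expected` arrays easy to reproduce. The sketch below is standalone and not part of the patch; the axis-set handling is an assumption modelled on `x.sum(axis=..., keepdims=True)`. It recomputes the `axes = [0, 2]` case:

```cpp
// Standalone reference model, not part of the patch: keep_dims sum of the
// 3x4x2 test input over an arbitrary set of axes. Indices along reduced
// axes collapse to 0, exactly as keep_dims = true leaves size-1 dimensions.
#include <cstdint>
#include <cstdio>
#include <set>

int main() {
  const uint8_t x[3][4][2]{
      {{0, 1}, {2, 3}, {4, 5}, {6, 7}},
      {{8, 9}, {0, 1}, {2, 3}, {4, 5}},
      {{6, 7}, {8, 9}, {0, 1}, {2, 3}},
  };
  const std::set<int> axes{0, 2}; // mirrors the axes = [0, 2] test

  uint64_t out[3][4][2] = {}; // large enough for any keep_dims result shape
  for (int i = 0; i < 3; i++)
    for (int j = 0; j < 4; j++)
      for (int k = 0; k < 2; k++)
        out[axes.count(0) ? 0 : i][axes.count(1) ? 0 : j]
           [axes.count(2) ? 0 : k] += x[i][j][k];

  // Prints 31 23 15 27, matching expected[1][4][1] of the axes_0_2 test.
  for (int j = 0; j < 4; j++)
    printf("%llu ", (unsigned long long)out[0][j][0]);
  printf("\n");
}
```

Swapping the `axes` set reproduces the other tests as well, e.g. `{1, 2}` yields 28, 32, 36 and `{0, 1, 2}` yields the grand total of 96.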
+
+TEST(End2EndJit_FHELinalg, sum_3D_keep_dims_axes_0_1_2) {
+  namespace concretelang = mlir::concretelang;
+
+  concretelang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX(
+
+func @main(%x: tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>> {
+  %0 = "FHELinalg.sum"(%x) { axes = [0, 1, 2], keep_dims = true } : (tensor<3x4x2x!FHE.eint<7>>) -> tensor<1x1x1x!FHE.eint<7>>
+  return %0 : tensor<1x1x1x!FHE.eint<7>>
+}
+
+)XXX");
+
+  const uint8_t x[3][4][2]{
+      {
+          {0, 1},
+          {2, 3},
+          {4, 5},
+          {6, 7},
+      },
+      {
+          {8, 9},
+          {0, 1},
+          {2, 3},
+          {4, 5},
+      },
+      {
+          {6, 7},
+          {8, 9},
+          {0, 1},
+          {2, 3},
+      },
+  };
+  const uint8_t expected[1][1][1] = {{{96}}};
+
+  llvm::ArrayRef<uint8_t> xRef((const uint8_t *)x, 3 * 4 * 2);
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint8_t>>
+      xArg(xRef, {3, 4, 2});
+
+  llvm::Expected<std::unique_ptr<concretelang::LambdaArgument>> call =
+      lambda.operator()<std::unique_ptr<concretelang::LambdaArgument>>(
+          {&xArg});
+  ASSERT_EXPECTED_SUCCESS(call);
+
+  concretelang::TensorLambdaArgument<concretelang::IntLambdaArgument<uint64_t>>
+      &res = (*call)
+                 ->cast<concretelang::TensorLambdaArgument<
+                     concretelang::IntLambdaArgument<uint64_t>>>();
+
+  ASSERT_EQ(res.getDimensions().size(), (size_t)3);
+  ASSERT_EQ(res.getDimensions().at(0), 1);
+  ASSERT_EQ(res.getDimensions().at(1), 1);
+  ASSERT_EQ(res.getDimensions().at(2), 1);
+  ASSERT_EXPECTED_VALUE(res.getNumElements(), 1);
+
+  for (size_t i = 0; i < 1; i++) {
+    for (size_t j = 0; j < 1; j++) {
+      for (size_t k = 0; k < 1; k++) {
+        EXPECT_EQ(res.getValue()[(i * 1 * 1) + (j * 1) + k], expected[i][j][k])
+            << ", at pos(" << i << "," << j << "," << k << ")";
+      }
+    }
+  }
+}
+
 class TiledMatMulParametric : public ::testing::TestWithParam> {};