diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td index 9a4ad05e9..244b2f523 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.td @@ -148,6 +148,71 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> { let hasVerifier = 1; } +def SubEintIntOp : FHE_Op<"sub_eint_int"> { + + let summary = "Substract a clear integer from an encrypted integer"; + + let description = [{ + Substract a clear integer from an encrypted integer. + The clear integer must have at most one more bit than the encrypted integer + and the result must have the same width than the encrypted integer. + + Example: + ```mlir + // ok + "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + + // error + "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2> + "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3> + ``` + }]; + + let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; + + let hasVerifier = 1; + + let hasFolder = 1; +} + +def SubEintOp : FHE_Op<"sub_eint"> { + + let summary = "Subtracts two encrypted integers"; + + let description = [{ + Subtracts two encrypted integers + The encrypted integers and the result must have the same width. + + Example: + ```mlir + // ok + "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + + // error + "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>) + "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>) + ``` + }]; + + let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b); + let results = (outs EncryptedIntegerType); + + let builders = [ + OpBuilder<(ins "Value":$a, "Value":$b), [{ + build($_builder, $_state, a.getType(), a, b); + }]> + ]; + + let hasVerifier = 1; +} + def NegEintOp : FHE_Op<"neg_eint"> { let summary = "Negates an encrypted integer"; @@ -224,7 +289,7 @@ def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> { "FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<2>) "FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>) "FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<3>, tensor<4xi64>) -> (!FHE.eint<2>) - + // error "FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<8xi64>) -> (!FHE.eint<2>) ``` diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 6ed23634c..271b1313e 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -181,6 +181,114 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor ]; } +def SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> { + let summary = "Returns a tensor that contains the substraction of a tensor of clear integers from a tensor of encrypted integers."; + + let description = [{ + Performs a substraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers. + The width of the clear integers must be less than or equals to the witdh of encrypted integers. + + Examples: + ```mlir + // Returns the term to term substraction of `%a0` with `%a1` + "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>> + + // Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched. + "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x4x4x!FHE.eint<4>>, tensor<4x1x4xi5>) -> tensor<4x4x4x!FHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // + // [1,2,3] [1] [0,2,3] + // [4,5,6] - [2] = [2,3,4] + // [7,8,9] [3] [4,5,6] + // + // The dimension #1 of operand #2 is stretched as it is equals to 1. + "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x1x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // + // [1,2,3] [0,0,0] + // [4,5,6] - [1,2,3] = [3,3,3] + // [7,8,9] [6,6,6] + // + // The dimension #2 of operand #2 is stretched as it is equals to 1. + "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>> + + // Same behavior than the previous one, but as the dimension #2 is missing of operand #2. + "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>> + + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ + build($_builder, $_state, lhs.getType(), lhs, rhs); + }]> + + ]; + + let hasFolder = 1; +} + +def SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, TensorBinaryEint]> { + let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers."; + + let description = [{ + Performs an subtraction follwing the broadcasting rules between two tensors of encrypted integers. + The width of the encrypted integers must be equal. + + Examples: + ```mlir + // Returns the term to term subtraction of `%a0` with `%a1` + "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> + + // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched. + "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers. + // + // [1,2,3] [1] [0,2,3] + // [4,5,6] - [2] = [2,3,4] + // [7,8,9] [3] [4,5,6] + // + // The dimension #1 of operand #2 is stretched as it is equals to 1. + "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + + // Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers. + // + // [1,2,3] [0,0,0] + // [4,5,6] - [1,2,3] = [3,3,3] + // [7,8,9] [6,6,6] + // + // The dimension #2 of operand #2 is stretched as it is equals to 1. + "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + + // Same behavior than the previous one, but as the dimension #2 of operand #2 is missing. + "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + ``` + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$lhs, + Type.predicate, HasStaticShapePred]>>:$rhs + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let builders = [ + OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{ + build($_builder, $_state, lhs.getType(), lhs, rhs); + }]> + ]; +} + def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> { let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers."; diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 895da78aa..f7e4530d7 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1613,6 +1613,14 @@ void FHETensorOpsToLinalg::runOnOperation() { FHELinalgOpToLinalgGeneric>( &getContext()); + patterns.insert< + FHELinalgOpToLinalgGeneric>( + &getContext()); + patterns.insert< + FHELinalgOpToLinalgGeneric>( + &getContext()); patterns.insert< FHELinalgOpToLinalgGeneric>( diff --git a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp index db168be7b..f6cc093e9 100644 --- a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp +++ b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp @@ -115,6 +115,85 @@ struct ApplyLookupTableEintOpPattern }; }; +// This rewrite pattern transforms any instance of `FHE.sub_eint_int` +// operators to a negation and an addition. +struct SubEintIntOpPattern : public mlir::OpRewritePattern { + SubEintIntOpPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::SubEintIntOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Location location = op.getLoc(); + + mlir::Value lhs = op.getOperand(0); + mlir::Value rhs = op.getOperand(1); + + mlir::Type rhsType = rhs.getType(); + mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(rhsType, -1); + mlir::Value minusOne = + rewriter.create(location, minusOneAttr) + .getResult(); + + mlir::Value negative = + rewriter.create(location, rhs, minusOne) + .getResult(); + + FHEToTFHETypeConverter converter; + auto resultTy = converter.convertType(op.getType()); + + auto addition = + rewriter.create(location, resultTy, lhs, negative); + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + rewriter.replaceOp(op, {addition.getResult()}); + return mlir::success(); + }; +}; + +// This rewrite pattern transforms any instance of `FHE.sub_eint` +// operators to a negation and an addition. +struct SubEintOpPattern : public mlir::OpRewritePattern { + SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(FHE::SubEintOp op, + mlir::PatternRewriter &rewriter) const override { + mlir::Location location = op.getLoc(); + + mlir::Value lhs = op.getOperand(0); + mlir::Value rhs = op.getOperand(1); + + FHEToTFHETypeConverter converter; + + auto rhsTy = converter.convertType(rhs.getType()); + auto negative = rewriter.create(location, rhsTy, rhs); + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, negative, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + auto resultTy = converter.convertType(op.getType()); + auto addition = rewriter.create(location, resultTy, lhs, + negative.getResult()); + + mlir::concretelang::convertOperandAndResultTypes( + rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) { + return converter.convertType(t); + }); + + rewriter.replaceOp(op, {addition.getResult()}); + return mlir::success(); + }; +}; + void FHEToTFHEPass::runOnOperation() { auto op = this->getOperation(); @@ -123,6 +202,7 @@ void FHEToTFHEPass::runOnOperation() { // Mark ops from the target dialect as legal operations target.addLegalDialect(); + target.addLegalDialect(); // Make sure that no ops from `FHE` remain after the lowering target.addIllegalDialect(); @@ -155,6 +235,9 @@ void FHEToTFHEPass::runOnOperation() { patterns.getContext(), converter); patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + patterns.add>( &getContext(), converter); diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 2c78a03cb..801c03af9 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -429,6 +429,57 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUAdd(sqNorm, eNorm); } +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `FHE.sub_eint_int` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::SubEintIntOp op, + llvm::ArrayRef *> operandMANPs) { + mlir::Type iTy = op->getOpOperand(1).get().getType(); + + assert(iTy.isSignlessInteger() && + "Only subtractions with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt sqNorm; + + mlir::arith::ConstantOp cstOp = + llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + + if (cstOp) { + // For constant plaintext operands simply use the constant value + mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); + sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); + } else { + // For dynamic plaintext operands conservatively assume that the integer has + // its maximum possible value + sqNorm = conservativeIntNorm2Sq(iTy); + } + return APIntWidthExtendUAdd(sqNorm, eNorm); +} + +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `FHE.sub_eint` operation. +static llvm::APInt getSqMANP( + mlir::concretelang::FHE::SubEintOp op, + llvm::ArrayRef *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue(); + + return APIntWidthExtendUAdd(a, b); +} + // Calculates the squared Minimal Arithmetic Noise Padding of a dot operation // that is equivalent to an `FHE.neg_eint` operation. static llvm::APInt getSqMANP( @@ -575,6 +626,58 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUAdd(sqNorm, eNorm); } +static llvm::APInt getSqMANP( + mlir::concretelang::FHELinalg::SubEintIntOp op, + llvm::ArrayRef *> operandMANPs) { + + mlir::RankedTensorType op1Ty = + op->getOpOperand(1).get().getType().cast(); + + mlir::Type iTy = op1Ty.getElementType(); + + assert(iTy.isSignlessInteger() && + "Only subtractions with signless integers are currently allowed"); + + assert( + operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt sqNorm; + + mlir::arith::ConstantOp cstOp = + llvm::dyn_cast_or_null( + op->getOpOperand(1).get().getDefiningOp()); + mlir::DenseIntElementsAttr denseVals = + cstOp ? cstOp->getAttrOfType("value") + : nullptr; + + if (denseVals) { + sqNorm = maxIntNorm2Sq(denseVals); + } else { + // For dynamic plaintext operands conservatively assume that the integer has + // its maximum possible value + sqNorm = conservativeIntNorm2Sq(iTy); + } + return APIntWidthExtendUAdd(sqNorm, eNorm); +} + +static llvm::APInt getSqMANP( + mlir::concretelang::FHELinalg::SubEintOp op, + llvm::ArrayRef *> operandMANPs) { + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + operandMANPs[1]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue(); + llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue(); + + return APIntWidthExtendUAdd(a, b); +} + // Calculates the squared Minimal Arithmetic Noise Padding of a dot operation // that is equivalent to an `FHELinalg.neg_eint` operation. static llvm::APInt getSqMANP( @@ -1192,6 +1295,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto subIntEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(subIntEintOp, operands); + } else if (auto subEintIntOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(subEintIntOp, operands); + } else if (auto subEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(subEintOp, operands); } else if (auto negEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(negEintOp, operands); @@ -1219,6 +1328,14 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast( op)) { norm2SqEquiv = getSqMANP(subIntEintOp, operands); + } else if (auto subEintIntOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = getSqMANP(subEintIntOp, operands); + } else if (auto subEintOp = + llvm::dyn_cast( + op)) { + norm2SqEquiv = getSqMANP(subEintOp, operands); } else if (auto negEintOp = llvm::dyn_cast( op)) { diff --git a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp index cbf49aa66..033961928 100644 --- a/compiler/lib/Dialect/FHE/IR/FHEOps.cpp +++ b/compiler/lib/Dialect/FHE/IR/FHEOps.cpp @@ -89,6 +89,35 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op, return ::mlir::success(); } +::mlir::LogicalResult SubEintIntOp::verify() { + auto a = this->a().getType().cast(); + auto b = this->b().getType().cast(); + auto out = this->getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, + out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(), + a, b)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + +::mlir::LogicalResult SubEintOp::verify() { + auto a = this->a().getType().cast(); + auto b = this->b().getType().cast(); + auto out = this->getResult().getType().cast(); + if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a, + out)) { + return ::mlir::failure(); + } + if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) { + return ::mlir::failure(); + } + return ::mlir::success(); +} + ::mlir::LogicalResult NegEintOp::verify() { auto a = this->a().getType().cast(); auto out = this->getResult().getType().cast(); @@ -147,6 +176,19 @@ OpFoldResult AddEintIntOp::fold(ArrayRef operands) { return nullptr; } +// Avoid subtraction with constant 0 +OpFoldResult SubEintIntOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto toSub = operands[1].dyn_cast_or_null(); + if (toSub != nullptr) { + auto intToSub = toSub.getInt(); + if (intToSub == 0) { + return getOperand(0); + } + } + return nullptr; +} + // Avoid multiplication with constant 1 OpFoldResult MulEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index 89add2677..f4996cb90 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -1731,6 +1731,20 @@ OpFoldResult AddEintIntOp::fold(ArrayRef operands) { return getOperand(0); } +// Avoid subtraction with constant tensor of 0s +OpFoldResult SubEintIntOp::fold(ArrayRef operands) { + assert(operands.size() == 2); + auto toSub = operands[1].dyn_cast_or_null(); + if (toSub == nullptr) + return nullptr; + for (auto it = toSub.begin(); it != toSub.end(); it++) { + if (*it != 0) { + return nullptr; + } + } + return getOperand(0); +} + // Avoid multiplication with constant tensor of 1s OpFoldResult MulEintIntOp::fold(ArrayRef operands) { assert(operands.size() == 2); diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir index b5dc0d1a1..00c8653f3 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir @@ -100,6 +100,60 @@ func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> // ----- +func @single_cst_sub_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> +{ + %cst = arith.constant 3 : i3 + + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + + return %0 : !FHE.eint<2> +} + +// ----- + +func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2> +{ + %cst = arith.constant -3 : i3 + + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + + return %0 : !FHE.eint<2> +} + +// ----- + +func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> +{ + // sqrt(1 + (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + %0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + + return %0 : !FHE.eint<2> +} + +// ----- + +func @chain_sub_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2>, %e3: !FHE.eint<2>, %e4: !FHE.eint<2>) -> !FHE.eint<2> +{ + // CHECK: %[[V0:.*]] = "FHE.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + %0 = "FHE.sub_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + + // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%[[V0]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + %1 = "FHE.sub_eint"(%0, %e2) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + + // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint"(%[[V1]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + %2 = "FHE.sub_eint"(%1, %e3) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + + // CHECK-NEXT: %[[V3:.*]] = "FHE.sub_eint"(%[[V2]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + %3 = "FHE.sub_eint"(%2, %e4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + + return %3 : !FHE.eint<2> +} + +// ----- + func @single_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK: %[[ret:.*]] = "FHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2> @@ -147,7 +201,7 @@ func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> // ----- func @single_apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> + // CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2> return %1: !FHE.eint<2> } diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir index e157b2513..614a6f064 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir @@ -70,6 +70,41 @@ func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> t // ----- +func @single_cst_sub_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> + + // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + +func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + %cst1 = arith.constant 1 : i3 + %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> + + // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + +func @single_sub_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.sub_eint"(%e0, %e1) : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { // CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> diff --git a/compiler/tests/Dialect/FHE/FHE/folding.mlir b/compiler/tests/Dialect/FHE/FHE/folding.mlir index 796abb71b..b8914541d 100644 --- a/compiler/tests/Dialect/FHE/FHE/folding.mlir +++ b/compiler/tests/Dialect/FHE/FHE/folding.mlir @@ -9,6 +9,15 @@ func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: return %arg0 : !FHE.eint<2> + + %0 = arith.constant 0 : i3 + %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + // CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: return %arg0 : !FHE.eint<2> diff --git a/compiler/tests/Dialect/FHE/FHE/ops.mlir b/compiler/tests/Dialect/FHE/FHE/ops.mlir index 9f7770720..20878544c 100644 --- a/compiler/tests/Dialect/FHE/FHE/ops.mlir +++ b/compiler/tests/Dialect/FHE/FHE/ops.mlir @@ -49,6 +49,26 @@ func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { return %1: !FHE.eint<2> } +// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> +func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3 + // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint_int"(%arg0, %[[V1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // CHECK-NEXT: return %[[V2]] : !FHE.eint<2> + + %0 = arith.constant 1 : i3 + %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + +// CHECK-LABEL: func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> +func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { + // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2> + // CHECK-NEXT: return %[[V1]] : !FHE.eint<2> + + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>) + return %1: !FHE.eint<2> +} + // CHECK-LABEL: func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { // CHECK-NEXT: %[[V1:.*]] = "FHE.neg_eint"(%arg0) : (!FHE.eint<2>) -> !FHE.eint<2> diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir index ff600ddb7..e6dd5494f 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/folding.mlir @@ -27,6 +27,33 @@ func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FH return %1: tensor<4x3x!FHE.eint<2>> } +// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { + %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3> + %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + return %1: tensor<4x!FHE.eint<2>> +} + +// CHECK: func @sub_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { + %a1 = arith.constant dense<[0]> : tensor<1xi3> + %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>> + return %1: tensor<4x!FHE.eint<2>> +} + +// CHECK: func @sub_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> { +// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> { + %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3> + %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>> + return %1: tensor<4x3x!FHE.eint<2>> +} + // CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { // CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>> // CHECK-NEXT: } diff --git a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir index 3a08ad05b..0ad5241f3 100644 --- a/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir +++ b/compiler/tests/Dialect/FHELinalg/FHELinalg/ops.mlir @@ -110,7 +110,7 @@ func @add_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.ein ///////////////////////////////////////////////// -// FHELinalg.sub_eint_int +// FHELinalg.sub_int_eint ///////////////////////////////////////////////// // 1D tensor @@ -164,6 +164,116 @@ func @sub_int_eint_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>) } +///////////////////////////////////////////////// +// FHELinalg.sub_eint_int +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_1D(%a0: tensor<4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + return %1: tensor<4x!FHE.eint<2>> +} + +// 2D tensor +// CHECK: func @sub_eint_int_2D(%[[a0:.*]]: tensor<2x4xi3>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_2D(%a0: tensor<2x4xi3>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>> + return %1: tensor<2x4x!FHE.eint<2>> +} + +// 10D tensor +// CHECK: func @sub_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +} + +// Broadcasting with tensor with dimensions equals to one +// CHECK: func @sub_eint_int_broadcast_1(%[[a0:.*]]: tensor<3x4x1xi3>, %[[a1:.*]]: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_broadcast_1(%a0: tensor<3x4x1xi3>, %a1: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>> + return %1: tensor<3x4x5x!FHE.eint<2>> +} + +// Broadcasting with a tensor less dimensions of another +// CHECK: func @sub_eint_int_broadcast_2(%[[a0:.*]]: tensor<3x4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_int_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { + %1 ="FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>> + return %1: tensor<3x4x!FHE.eint<2>> +} + + +///////////////////////////////////////////////// +// FHELinalg.sub_eint +///////////////////////////////////////////////// + +// 1D tensor +// CHECK: func @sub_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_1D(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> + return %1: tensor<4x!FHE.eint<2>> +} + +// 2D tensor +// CHECK: func @sub_eint_2D(%[[a0:.*]]: tensor<2x4x!FHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_2D(%a0: tensor<2x4x!FHE.eint<2>>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> + return %1: tensor<2x4x!FHE.eint<2>> +} + +// 10D tensor +// CHECK: func @sub_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> + return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> +} + +// Broadcasting with tensor with dimensions equals to one +// CHECK: func @sub_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_broadcast_1(%a0: tensor<1x4x5x!FHE.eint<2>>, %a1: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> { + %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> + return %1: tensor<3x4x5x!FHE.eint<2>> +} + +// Broadcasting with a tensor less dimensions of another +// CHECK: func @sub_eint_broadcast_2(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>> +// CHECK-NEXT: } +func @sub_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> { + %1 ="FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> + return %1: tensor<3x4x!FHE.eint<2>> +} + + ///////////////////////////////////////////////// // FHELinalg.neg_eint ///////////////////////////////////////////////// diff --git a/compiler/tests/fixture/end_to_end_fhe.yaml b/compiler/tests/fixture/end_to_end_fhe.yaml index f4785094e..b2ed17e9f 100644 --- a/compiler/tests/fixture/end_to_end_fhe.yaml +++ b/compiler/tests/fixture/end_to_end_fhe.yaml @@ -63,7 +63,7 @@ tests: outputs: - scalar: 3 --- -description: sub_eint_int_cst +description: sub_int_eint_cst program: | func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> { %0 = arith.constant 7 : i3 @@ -88,6 +88,69 @@ tests: outputs: - scalar: 3 --- +description: sub_eint_int_cst +program: | + func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = arith.constant 7 : i6 + %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<5>, i6) -> (!FHE.eint<5>) + return %1: !FHE.eint<5> + } +tests: + - inputs: + - scalar: 10 + outputs: + - scalar: 3 + - inputs: + - scalar: 7 + outputs: + - scalar: 0 +--- +description: sub_eint_int_arg +program: | + func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> { + %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<4>, i5) -> (!FHE.eint<4>) + return %1: !FHE.eint<4> + } +tests: + - inputs: + - scalar: 2 + - scalar: 2 + outputs: + - scalar: 0 + - inputs: + - scalar: 3 + - scalar: 1 + outputs: + - scalar: 2 + - inputs: + - scalar: 7 + - scalar: 4 + outputs: + - scalar: 3 +--- +description: sub_eint +program: | + func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> { + %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>) + return %1: !FHE.eint<4> + } +tests: + - inputs: + - scalar: 2 + - scalar: 2 + outputs: + - scalar: 0 + - inputs: + - scalar: 3 + - scalar: 1 + outputs: + - scalar: 2 + - inputs: + - scalar: 7 + - scalar: 4 + outputs: + - scalar: 3 +--- description: sub_int_eint_arg program: | func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> { diff --git a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc index 9120da9dd..cf6300486 100644 --- a/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_fhelinalg.cc @@ -791,6 +791,447 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) { } } +/////////////////////////////////////////////////////////////////////////////// +// FHELinalg sub_eint_int /////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>> + return %res : tensor<4x!FHE.eint<4>> + } +)XXX"); + std::vector a0{31, 6, 2, 3}; + std::vector a1{32, 9, 12, 9}; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(a0); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(a1); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)4); + + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]); + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!FHE.eint<7>>) -> tensor<4x4x4x!FHE.eint<7>> { + %res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x4x!FHE.eint<7>>, tensor<4x1x4xi8>) -> tensor<4x4x4x!FHE.eint<7>> + return %res : tensor<4x4x4x!FHE.eint<7>> + } +)XXX"); + const uint8_t a0[4][1][4]{ + {{1, 2, 3, 4}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}, + {{13, 14, 15, 16}}, + }; + const uint8_t a1[1][4][4]{ + { + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10, 11, 12}, + {13, 14, 15, 16}, + }, + }; + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 4 * 1 * 4), {4, 1, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 1 * 4 * 4), {1, 4, 4}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4); + + for (size_t i = 0; i < 4; i++) { + for (size_t j = 0; j < 4; j++) { + for (size_t k = 0; k < 4; k++) { + uint8_t expected = a1[i][0][k] - a0[0][j][k]; + EXPECT_EQ((*res)[i * 16 + j * 4 + k], (uint64_t)expected); + } + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_column) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1xi5>) -> + tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[3][1]{ + {1}, + {2}, + {3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 3 * 1), {3, 1}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]); + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3xi5>) -> + tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[1][3]{ + {1, 2, 3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 3 * 1), {1, 3}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]); + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line_missing_dim) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[1][3]{ + {1, 2, 3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 3 * 1), {3}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]); + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// FHELinalg add_eint /////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_FHELinalg, sub_eint_term_to_term) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> + return %res : tensor<4x!FHE.eint<6>> + } +)XXX"); + + std::vector a0{31, 6, 12, 9}; + std::vector a1{4, 2, 9, 3}; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(a0); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(a1); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)4); + + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ((*res)[i], (uint64_t)a0[i] - a1[i]) + << "result differ at pos " << i << ", expect " << a0[i] + a1[i] + << " got " << (*res)[i]; + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_term_to_term_broadcast) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<4x1x4x!FHE.eint<5>>, %a1: tensor<1x4x4x!FHE.eint<5>>) -> tensor<4x4x4x!FHE.eint<5>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<5>>, tensor<1x4x4x!FHE.eint<5>>) -> + tensor<4x4x4x!FHE.eint<5>> return %res : tensor<4x4x4x!FHE.eint<5>> + } +)XXX"); + uint8_t a0[4][1][4]{ + {{10, 20, 30, 40}}, + {{5, 6, 7, 8}}, + {{9, 10, 11, 12}}, + {{13, 14, 15, 16}}, + }; + uint8_t a1[1][4][4]{ + { + {1, 2, 3, 4}, + {4, 3, 2, 1}, + {3, 1, 4, 2}, + {2, 4, 1, 3}, + }, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::MutableArrayRef((uint8_t *)a0, 4 * 1 * 4), {4, 1, 4}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::MutableArrayRef((uint8_t *)a1, 1 * 4 * 4), {1, 4, 4}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4); + + for (size_t i = 0; i < 4; i++) { + for (size_t j = 0; j < 4; j++) { + for (size_t k = 0; k < 4; k++) { + EXPECT_EQ((*res)[i * 16 + j * 4 + k], + (uint64_t)a0[i][0][k] - a1[0][j][k]) + << "result differ at pos " << i << "," << j << "," << k; + } + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_matrix_column) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[3][1]{ + {1}, + {2}, + {3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a0, 3 * 1), {3, 1}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]); + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_matrix_line) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[1][3]{ + {1, 2, 3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a0, 3 * 1), {1, 3}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]); + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_matrix_line_missing_dim) { + + checkedJit(lambda, R"XXX( + func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> { + %res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> + return %res : tensor<3x3x!FHE.eint<4>> + } +)XXX"); + const uint8_t a0[3][3]{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + }; + const uint8_t a1[1][3]{ + {1, 2, 3}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 3), {3, 3}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a0, 3 * 1), {3}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]); + } + } +} + +TEST(End2EndJit_FHELinalg, sub_eint_tensor_dim_equals_1) { + + checkedJit(lambda, R"XXX( + func @main(%arg0: tensor<3x1x2x!FHE.eint<5>>, %arg1: tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>> { + %1 = "FHELinalg.sub_eint"(%arg0, %arg1) : (tensor<3x1x2x!FHE.eint<5>>, tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>> + return %1 : tensor<3x1x2x!FHE.eint<5>> + } +)XXX"); + const uint8_t a0[3][1][2]{ + {{8, 10}}, + {{12, 14}}, + {{16, 18}}, + }; + const uint8_t a1[3][1][2]{ + {{1, 2}}, + {{4, 5}}, + {{7, 8}}, + }; + + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg0(llvm::ArrayRef((const uint8_t *)a0, 3 * 2), {3, 1, 2}); + mlir::concretelang::TensorLambdaArgument< + mlir::concretelang::IntLambdaArgument> + arg1(llvm::ArrayRef((const uint8_t *)a1, 3 * 2), {3, 1, 2}); + + llvm::Expected> res = + lambda.operator()>({&arg0, &arg1}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), (uint64_t)3 * 1 * 2); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 1; j++) { + for (size_t k = 0; k < 2; k++) { + EXPECT_EQ((*res)[i * 2 + j + k], (uint64_t)a0[i][j][k] - a1[i][j][k]); + } + } + } +} + /////////////////////////////////////////////////////////////////////////////// // FHELinalg mul_eint_int /////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////