From ae9a04cd56767dea46b12c1fcb9666fffcbe5b4a Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 20 May 2022 13:39:22 +0200 Subject: [PATCH] fix(compiler): Use the absolute value when computing the square of a constant in MANP analysis The computation of the norm2 should take care of the sign of the constant to compute the square of this constant. --- compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 107 ++++++------- .../tests/Dialect/FHE/FHE/Analysis/MANP.mlir | 9 +- .../Dialect/FHE/FHE/Analysis/MANP_linalg.mlir | 150 +++++++++++------- .../Dialect/FHE/FHE/Analysis/MANP_tensor.mlir | 68 ++------ 4 files changed, 173 insertions(+), 161 deletions(-) diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index f2a913ab9..bed1433bf 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -183,7 +183,7 @@ static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) { // Calculates the square of `i`. The bit width `i` is extended in // order to guarantee that the product fits into the resulting // `APInt`. -static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) { +static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) { // Make sure the required number of bits can be represented by the // `unsigned` argument of `zext`. assert(i.getBitWidth() < std::numeric_limits::max() / 2 && @@ -194,12 +194,22 @@ static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) { return ie * ie; } +// Calculates the square of the absolute value of `i`. +static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) { + // Make sure the required number of bits can be represented by the + // `unsigned` argument of `zext`. + assert(i.getBitWidth() < 32 && + "Square of the constant cannot be represented on 64 bits"); + return llvm::APInt(2 * i.getBitWidth(), + i.abs().getZExtValue() * i.abs().getZExtValue()); +} + // Calculates the square root of `i` and rounds it to the next highest // integer value (i.e., the square of the result is guaranteed to be // greater or equal to `i`). static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) { llvm::APInt res = i.sqrt(); - llvm::APInt resSq = APIntWidthExtendUSq(res); + llvm::APInt resSq = APIntWidthExtendUnsignedSq(res); if (APIntWidthExtendULT(resSq, i)) return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false}); @@ -234,7 +244,7 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp, llvm::APInt accu{1, 0, false}; for (llvm::APInt val : denseVals.getValues()) { - llvm::APInt valSqNorm = APIntWidthExtendUSq(val); + llvm::APInt valSqNorm = APIntWidthExtendSqForConstant(val); llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm); accu = APIntWidthExtendUAdd(accu, mulSqNorm); } @@ -258,8 +268,9 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy, unsigned elWidth = tTy.getElementTypeBitWidth(); - llvm::APInt maxVal = APInt::getMaxValue(elWidth); - llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal); + llvm::APInt maxVal = APInt::getSignedMaxValue(elWidth); + llvm::APInt maxValSq = APIntWidthExtendUnsignedSq(maxVal); + llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm); // Calculate number of bits for APInt to store number of elements @@ -272,6 +283,30 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy, return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP); } +// Returns the squared 2-norm of the maximum value of the dense values. +static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) { + // For a constant operand use actual constant to calculate 2-norm + llvm::APInt maxCst = denseVals.getFlatValue(0); + for (int64_t i = 0; i < denseVals.getNumElements(); i++) { + llvm::APInt iCst = denseVals.getFlatValue(i); + if (maxCst.ult(iCst)) { + maxCst = iCst; + } + } + return APIntWidthExtendSqForConstant(maxCst); +} + +// Returns the squared 2-norm for a dynamic integer by conservatively +// assuming that the integer's value is the maximum for the integer +// width. +static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { + assert(t.isSignlessInteger() && "Type must be a signless integer type"); + assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); + + llvm::APInt maxVal = APInt::getSignedMaxValue(t.getIntOrFloatBitWidth()); + return APIntWidthExtendUnsignedSq(maxVal); +} + // Calculates the squared Minimal Arithmetic Noise Padding of an // `FHELinalg.dot_eint_int` operation. static llvm::APInt getSqMANP( @@ -306,18 +341,6 @@ static llvm::APInt getSqMANP( } } -// Returns the squared 2-norm for a dynamic integer by conservatively -// assuming that the integer's value is the maximum for the integer -// width. -static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) { - assert(t.isSignlessInteger() && "Type must be a signless integer type"); - assert(std::numeric_limits::max() - t.getIntOrFloatBitWidth() > 1); - - llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false}; - maxVal <<= t.getIntOrFloatBitWidth(); - return APIntWidthExtendUSq(maxVal); -} - // Calculates the squared Minimal Arithmetic Noise Padding of an // `FHE.add_eint_int` operation. static llvm::APInt getSqMANP( @@ -343,7 +366,7 @@ static llvm::APInt getSqMANP( if (cstOp) { // For a constant operand use actual constant to calculate 2-norm mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendUSq(attr.getValue()); + sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width @@ -395,7 +418,7 @@ static llvm::APInt getSqMANP( if (cstOp) { // For constant plaintext operands simply use the constant value mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendUSq(attr.getValue()); + sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); } else { // For dynamic plaintext operands conservatively assume that the integer has // its maximum possible value @@ -445,7 +468,7 @@ static llvm::APInt getSqMANP( if (cstOp) { // For a constant operand use actual constant to calculate 2-norm mlir::IntegerAttr attr = cstOp->getAttrOfType("value"); - sqNorm = APIntWidthExtendUSq(attr.getValue()); + sqNorm = APIntWidthExtendSqForConstant(attr.getValue()); } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width @@ -486,14 +509,7 @@ static llvm::APInt getSqMANP( if (denseVals) { // For a constant operand use actual constant to calculate 2-norm - llvm::APInt maxCst = denseVals.getFlatValue(0); - for (int64_t i = 0; i < denseVals.getNumElements(); i++) { - llvm::APInt iCst = denseVals.getFlatValue(i); - if (maxCst.ult(iCst)) { - maxCst = iCst; - } - } - sqNorm = APIntWidthExtendUSq(maxCst); + sqNorm = maxIntNorm2Sq(denseVals); } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width @@ -548,15 +564,7 @@ static llvm::APInt getSqMANP( : nullptr; if (denseVals) { - // For a constant operand use actual constant to calculate 2-norm - llvm::APInt maxCst = denseVals.getFlatValue(0); - for (int64_t i = 0; i < denseVals.getNumElements(); i++) { - llvm::APInt iCst = denseVals.getFlatValue(i); - if (maxCst.ult(iCst)) { - maxCst = iCst; - } - } - sqNorm = APIntWidthExtendUSq(maxCst); + sqNorm = maxIntNorm2Sq(denseVals); } else { // For dynamic plaintext operands conservatively assume that the integer has // its maximum possible value @@ -612,14 +620,7 @@ static llvm::APInt getSqMANP( if (denseVals) { // For a constant operand use actual constant to calculate 2-norm - llvm::APInt maxCst = denseVals.getFlatValue(0); - for (int64_t i = 0; i < denseVals.getNumElements(); i++) { - llvm::APInt iCst = denseVals.getFlatValue(i); - if (maxCst.ult(iCst)) { - maxCst = iCst; - } - } - sqNorm = APIntWidthExtendUSq(maxCst); + sqNorm = maxIntNorm2Sq(denseVals); } else { // For a dynamic operand conservatively assume that the value is // the maximum for the integer width @@ -639,7 +640,7 @@ static llvm::APInt computeVectorNorm( elementSelector[axis] = i; llvm::APInt weight = denseValues.getValue(elementSelector); - llvm::APInt weightNorm = APIntWidthExtendUSq(weight); + llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight); llvm::APInt multiplicationNorm = APIntWidthExtendUMul(encryptedOperandNorm, weightNorm); @@ -749,7 +750,7 @@ static llvm::APInt getSqMANP( for (int64_t n = 0; n < N; n++) { llvm::APInt cst = denseVals.getValue({(uint64_t)n, (uint64_t)p}); - llvm::APInt rhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); } @@ -765,7 +766,7 @@ static llvm::APInt getSqMANP( for (int64_t i = 0; i < N; i++) { llvm::APInt cst = denseVals.getFlatValue(i); - llvm::APInt rhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); } @@ -849,7 +850,7 @@ static llvm::APInt getSqMANP( for (int64_t n = 0; n < N; n++) { llvm::APInt cst = denseVals.getValue({(uint64_t)m, (uint64_t)n}); - llvm::APInt lhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); } @@ -865,7 +866,7 @@ static llvm::APInt getSqMANP( for (int64_t i = 0; i < N; i++) { llvm::APInt cst = denseVals.getFlatValue(i); - llvm::APInt lhsNorm = APIntWidthExtendUSq(cst); + llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm); accNorm = APIntWidthExtendUAdd(mulNorm, accNorm); } @@ -1106,14 +1107,14 @@ static llvm::APInt getSqMANP( // If there is a bias, start accumulating from its norm if (hasBias && biasDenseVals) { llvm::APInt cst = biasDenseVals.getFlatValue(f); - tmpNorm = APIntWidthExtendUSq(cst); + tmpNorm = APIntWidthExtendSqForConstant(cst); } for (uint64_t c = 0; c < C; c++) { for (uint64_t h = 0; h < H; h++) { for (uint64_t w = 0; w < W; w++) { llvm::APInt cst = weightDenseVals.getValue({f, c, h, w}); - llvm::APInt weightNorm = APIntWidthExtendUSq(cst); + llvm::APInt weightNorm = APIntWidthExtendSqForConstant(cst); llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm); tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm); } @@ -1138,7 +1139,7 @@ static llvm::APInt getSqMANP( llvm::APInt maxNorm = tmpNorm; for (uint64_t f = 0; f < F; f++) { llvm::APInt cst = biasDenseVals.getFlatValue(f); - llvm::APInt currentNorm = APIntWidthExtendUSq(cst); + llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst); currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm); maxNorm = APIntUMax(currentNorm, maxNorm); } diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir index 6d33c49e3..1f26728fc 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP.mlir @@ -34,7 +34,8 @@ func @single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // sqrt(1 + (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -66,7 +67,8 @@ func @single_cst_sub_int_eint(%e: !FHE.eint<2>) -> !FHE.eint<2> func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> + // sqrt(1 + (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2> %0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -98,7 +100,8 @@ func @single_cst_mul_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2> func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2> { - // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> + // sqrt(1 + (2^2-1)^2) = 3 + // CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> %0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2> return %0 : !FHE.eint<2> diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir index 1d1983f9a..e157b2513 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s +// RUN: concretecompiler --passes canonicalize --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { @@ -12,9 +12,22 @@ func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint< // ----- +func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + %cst1 = arith.constant 1 : i3 + %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> + + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // sqrt(1 + (2^2-1)^2) = 3..16 + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -44,6 +57,19 @@ func @single_cst_sub_int_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint< // ----- +func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + %cst1 = arith.constant 1 : i3 + %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> + + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> { // CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> @@ -56,7 +82,8 @@ func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> + // sqrt(1 + (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -76,9 +103,23 @@ func @single_cst_mul_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint< // ----- +func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +{ + %cst1 = arith.constant 1 : i3 + %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3> + + // %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + %0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + + return %0 : tensor<8x!FHE.eint<2>> +} + +// ----- + func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>> { - // CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // sqrt(1 * (2^2-1)^2) = 3.16 + // CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> return %0 : tensor<8x!FHE.eint<2>> @@ -86,21 +127,21 @@ func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> t // ----- -func @chain_add_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>> +func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint<3>> { - %cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3> - %cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3> - %cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3> - %cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3> - // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - %1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - %2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - %3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> - return %3 : tensor<8x!FHE.eint<2>> + %cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi4> + %cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi4> + %cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi4> + %cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi4> + // CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + %0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + %1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + %2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + // CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + %3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>> + return %3 : tensor<8x!FHE.eint<3>> } // ----- @@ -132,7 +173,7 @@ func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3 func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> { %lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64> - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>> @@ -151,7 +192,7 @@ func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<3x3x4 // ----- func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> { - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>> // CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> %res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> @@ -173,12 +214,13 @@ func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> return %0 : !FHE.eint<2> } + // ----- func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.eint<2> { - // sqrt(1*(2^3-1)^2*4) = 14 - // CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> + // sqrt(1*(2^2-1)^2*4) = 16 + // CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> %0 = "FHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %0 : !FHE.eint<2> @@ -188,13 +230,13 @@ func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.ein func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { - // sqrt((2^3)^2*1) = sqrt(64) = 8 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} + // sqrt((2^2-1)^2*1) = sqrt(9) = 3 + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> - %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3> - // sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[[0-9]+}}} + %cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3> + // sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12 + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -204,12 +246,12 @@ func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> ! func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2> { - // sqrt((2^3)^2*1) = sqrt(64) = 8 - // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} + // sqrt((2^2-1)^2*1) = sqrt(9) = 3 + // CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} %0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> - // sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112 - // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[[0-9]+}}} + // sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18 + // CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 18 : ui{{[[0-9]+}}} %1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2> return %1 : !FHE.eint<2> @@ -224,10 +266,10 @@ func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> ! func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> { // p = 0 // acc = manp(0) = 1 - // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 1 = 65 - // ceil(sqrt(65)) = 9 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 + // ceil(sqrt(65)) = 4 + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -237,13 +279,13 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2 func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> { // p = 0 // acc = manp(0) = 1 - // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 9 + 1 = 10 // p = 1 - // manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 65 = 129 - // ceil(sqrt(129)) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}} + // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 10 + 9 = 19 + // ceil(sqrt(19)) = 5 + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -507,10 +549,10 @@ func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<7>> { func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> { // p = 0 // acc = manp(0) = 1 - // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 1 = 65 - // ceil(sqrt(65)) = 9 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}} + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 64 + 1 = 10 + // ceil(sqrt(65)) = 4 + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -520,13 +562,13 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> { // p = 0 // acc = manp(0) = 1 - // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 1 = 65 + // mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 64 + 1 = 10 // p = 1 - // manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64 - // manp(add_eint(mul, acc)) = 64 + 65 = 129 + // manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9 + // manp(add_eint(mul, acc)) = 10 + 9 = 19 // ceil(sqrt(129)) = 12 - // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}} + // CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}} %1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> return %1 : tensor<3x2x!FHE.eint<2>> } @@ -969,7 +1011,7 @@ func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> te func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> { %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 64 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> @@ -980,7 +1022,7 @@ func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1 func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { %bias = arith.constant dense<[5]> : tensor<1xi3> - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -990,7 +1032,7 @@ func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x // ----- func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> @@ -1000,7 +1042,7 @@ func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: te // ----- func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> { - // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}} + // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}} %0 = "FHELinalg.conv2d"(%input, %weight, %bias){ strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64> } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_tensor.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_tensor.mlir index d5f0d631a..59eb6ed62 100644 --- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_tensor.mlir +++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_tensor.mlir @@ -40,17 +40,14 @@ func @tensor_extract_1(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> // ----- -func @tensor_extract_2(%a: !FHE.eint<2>) -> !FHE.eint<2> +func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2> { %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : i3 - - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> - %0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2> - // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> - %1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>> - // CHECK: %[[ret:.*]] = tensor.extract %[[V1]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> - %2 = tensor.extract %1[%c1] : tensor<4x!FHE.eint<2>> + %c3 = arith.constant dense<3> : tensor<4xi3> + // CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + %0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + // CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> + %2 = tensor.extract %0[%c1] : tensor<4x!FHE.eint<2>> return %2 : !FHE.eint<2> } @@ -67,16 +64,15 @@ func @tensor_extract_slice_1(%t: tensor<2x10x!FHE.eint<2>>) -> tensor<1x5x!FHE.e // ----- -func @tensor_extract_slice_2(%a: !FHE.eint<2>) -> tensor<2x!FHE.eint<2>> +func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.eint<2>> { - %c3 = arith.constant 3 : i3 + %c3 = arith.constant dense <3> : tensor<4xi3> - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2> - %0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2> - // CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>> - %1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>> - // CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> - %2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> + // CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} + %0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>> + + // CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> + %2 = tensor.extract_slice %0[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>> return %2 : tensor<2x!FHE.eint<2>> } @@ -93,36 +89,6 @@ func @tensor_insert_slice_1(%t0: tensor<2x10x!FHE.eint<2>>, %t1: tensor<2x2x!FHE // ----- -func @tensor_insert_slice_2(%a: !FHE.eint<5>) -> tensor<4x!FHE.eint<5>> -{ - %c3 = arith.constant 3 : i6 - %c6 = arith.constant 6 : i6 - - // CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5> - %v0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<5>, i6) -> !FHE.eint<5> - // CHECK: %[[V1:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5> - %v1 = "FHE.add_eint_int"(%a, %c6) : (!FHE.eint<5>, i6) -> !FHE.eint<5> - - // CHECK: %[[T0:.*]] = tensor.from_elements %[[V0]], %[[V0]], %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<5>> - %t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!FHE.eint<5>> - - // CHECK: %[[T1:.*]] = tensor.from_elements %[[V1]], %[[V1]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> - %t1 = tensor.from_elements %v1, %v1 : tensor<2x!FHE.eint<5>> - - // CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>> - %t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>> - - // CHECK: %[[T3:.*]] = tensor.from_elements %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> - %t3 = tensor.from_elements %v0, %v0 : tensor<2x!FHE.eint<5>> - - // CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>> - %t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>> - - return %t0 : tensor<4x!FHE.eint<5>> -} - -// ----- - func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE.eint<6>> { // CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}} %0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!FHE.eint<6>> into tensor<2x8x!FHE.eint<6>> @@ -133,9 +99,9 @@ func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>> - // CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}} + // CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}} %1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>> return %1 : tensor<2x8x!FHE.eint<2>> } @@ -152,9 +118,9 @@ func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x!FHE.e func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>> { - // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}} + // CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}} %0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>> - // CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}} + // CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}} %1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>> return %1 : tensor<2x2x4x!FHE.eint<2>> } \ No newline at end of file