fix(compiler): Use the absolute value when computing the square of a constant in MANP analysis

The computation of the 2-norm must take the sign of the constant into account when computing its square.
Author: Quentin Bourgerie
Date: 2022-05-20 13:39:22 +02:00
parent 6532c8f449
commit ae9a04cd56
4 changed files with 173 additions and 161 deletions
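
To make the effect concrete, here is a minimal standalone sketch (not part of the patch) contrasting the old zero-extend-then-square behaviour with the new square-of-the-absolute-value behaviour, assuming only LLVM's llvm::APInt API:

#include <cassert>
#include <llvm/ADT/APInt.h>

int main() {
  // A 3-bit constant holding -1 (bit pattern 0b111).
  llvm::APInt c(/*numBits=*/3, /*val=*/-1, /*isSigned=*/true);

  // Old behaviour: zero-extend and square. The bit pattern is read as the
  // unsigned value 7, so the "square" of -1 comes out as 49.
  llvm::APInt oldSq = c.zext(6) * c.zext(6);
  assert(oldSq.getZExtValue() == 49);

  // New behaviour (as in APIntWidthExtendSqForConstant): square the
  // absolute value, so (-1)^2 == 1.
  llvm::APInt newSq(2 * c.getBitWidth(),
                    c.abs().getZExtValue() * c.abs().getZExtValue());
  assert(newSq.getZExtValue() == 1);
  return 0;
}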


@@ -183,7 +183,7 @@ static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
// Calculates the square of `i`. The bit width `i` is extended in
// order to guarantee that the product fits into the resulting
// `APInt`.
static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
// Make sure the required number of bits can be represented by the
// `unsigned` argument of `zext`.
assert(i.getBitWidth() < std::numeric_limits<unsigned>::max() / 2 &&
@@ -194,12 +194,22 @@ static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
return ie * ie;
}
// Calculates the square of the absolute value of `i`.
static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
// Make sure the square of the constant fits into the 64-bit integer
// used for the computation.
assert(i.getBitWidth() < 32 &&
"Square of the constant cannot be represented on 64 bits");
return llvm::APInt(2 * i.getBitWidth(),
i.abs().getZExtValue() * i.abs().getZExtValue());
}
// Calculates the square root of `i` and rounds it to the next highest
// integer value (i.e., the square of the result is guaranteed to be
// greater or equal to `i`).
static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
llvm::APInt res = i.sqrt();
llvm::APInt resSq = APIntWidthExtendUSq(res);
llvm::APInt resSq = APIntWidthExtendUnsignedSq(res);
if (APIntWidthExtendULT(resSq, i))
return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false});
@@ -234,7 +244,7 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
llvm::APInt accu{1, 0, false};
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
llvm::APInt valSqNorm = APIntWidthExtendUSq(val);
llvm::APInt valSqNorm = APIntWidthExtendSqForConstant(val);
llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm);
accu = APIntWidthExtendUAdd(accu, mulSqNorm);
}
@@ -258,8 +268,9 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
unsigned elWidth = tTy.getElementTypeBitWidth();
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
llvm::APInt maxVal = APInt::getSignedMaxValue(elWidth);
llvm::APInt maxValSq = APIntWidthExtendUnsignedSq(maxVal);
llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm);
// Calculate number of bits for APInt to store number of elements
@@ -272,6 +283,30 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP);
}
// Returns the squared 2-norm of the maximum value of the dense values.
static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
return APIntWidthExtendSqForConstant(maxCst);
}
// Returns the squared 2-norm for a dynamic integer by conservatively
// assuming that the integer's value is the maximum for the integer
// width.
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
assert(t.isSignlessInteger() && "Type must be a signless integer type");
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
llvm::APInt maxVal = APInt::getSignedMaxValue(t.getIntOrFloatBitWidth());
return APIntWidthExtendUnsignedSq(maxVal);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHELinalg.dot_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -306,18 +341,6 @@ static llvm::APInt getSqMANP(
}
}
// Returns the squared 2-norm for a dynamic integer by conservatively
// assuming that the integer's value is the maximum for the integer
// width.
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
assert(t.isSignlessInteger() && "Type must be a signless integer type");
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false};
maxVal <<= t.getIntOrFloatBitWidth();
return APIntWidthExtendUSq(maxVal);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHE.add_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -343,7 +366,7 @@ static llvm::APInt getSqMANP(
if (cstOp) {
// For a constant operand use actual constant to calculate 2-norm
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
@@ -395,7 +418,7 @@ static llvm::APInt getSqMANP(
if (cstOp) {
// For constant plaintext operands simply use the constant value
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
} else {
// For dynamic plaintext operands conservatively assume that the integer has
// its maximum possible value
@@ -445,7 +468,7 @@ static llvm::APInt getSqMANP(
if (cstOp) {
// For a constant operand use actual constant to calculate 2-norm
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
sqNorm = APIntWidthExtendUSq(attr.getValue());
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
@@ -486,14 +509,7 @@ static llvm::APInt getSqMANP(
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
sqNorm = maxIntNorm2Sq(denseVals);
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
@@ -548,15 +564,7 @@ static llvm::APInt getSqMANP(
: nullptr;
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
sqNorm = maxIntNorm2Sq(denseVals);
} else {
// For dynamic plaintext operands conservatively assume that the integer has
// its maximum possible value
@@ -612,14 +620,7 @@ static llvm::APInt getSqMANP(
if (denseVals) {
// For a constant operand use actual constant to calculate 2-norm
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
if (maxCst.ult(iCst)) {
maxCst = iCst;
}
}
sqNorm = APIntWidthExtendUSq(maxCst);
sqNorm = maxIntNorm2Sq(denseVals);
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
@@ -639,7 +640,7 @@ static llvm::APInt computeVectorNorm(
elementSelector[axis] = i;
llvm::APInt weight = denseValues.getValue<llvm::APInt>(elementSelector);
llvm::APInt weightNorm = APIntWidthExtendUSq(weight);
llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight);
llvm::APInt multiplicationNorm =
APIntWidthExtendUMul(encryptedOperandNorm, weightNorm);
@@ -749,7 +750,7 @@ static llvm::APInt getSqMANP(
for (int64_t n = 0; n < N; n++) {
llvm::APInt cst =
denseVals.getValue<llvm::APInt>({(uint64_t)n, (uint64_t)p});
llvm::APInt rhsNorm = APIntWidthExtendUSq(cst);
llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
}
@@ -765,7 +766,7 @@ static llvm::APInt getSqMANP(
for (int64_t i = 0; i < N; i++) {
llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(i);
llvm::APInt rhsNorm = APIntWidthExtendUSq(cst);
llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
}
@@ -849,7 +850,7 @@ static llvm::APInt getSqMANP(
for (int64_t n = 0; n < N; n++) {
llvm::APInt cst =
denseVals.getValue<llvm::APInt>({(uint64_t)m, (uint64_t)n});
llvm::APInt lhsNorm = APIntWidthExtendUSq(cst);
llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
}
@@ -865,7 +866,7 @@ static llvm::APInt getSqMANP(
for (int64_t i = 0; i < N; i++) {
llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(i);
llvm::APInt lhsNorm = APIntWidthExtendUSq(cst);
llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
}
@@ -1106,14 +1107,14 @@ static llvm::APInt getSqMANP(
// If there is a bias, start accumulating from its norm
if (hasBias && biasDenseVals) {
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
tmpNorm = APIntWidthExtendUSq(cst);
tmpNorm = APIntWidthExtendSqForConstant(cst);
}
for (uint64_t c = 0; c < C; c++) {
for (uint64_t h = 0; h < H; h++) {
for (uint64_t w = 0; w < W; w++) {
llvm::APInt cst =
weightDenseVals.getValue<llvm::APInt>({f, c, h, w});
llvm::APInt weightNorm = APIntWidthExtendUSq(cst);
llvm::APInt weightNorm = APIntWidthExtendSqForConstant(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
}
@@ -1138,7 +1139,7 @@ static llvm::APInt getSqMANP(
llvm::APInt maxNorm = tmpNorm;
for (uint64_t f = 0; f < F; f++) {
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
llvm::APInt currentNorm = APIntWidthExtendUSq(cst);
llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst);
currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
maxNorm = APIntUMax(currentNorm, maxNorm);
}
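
The test updates below follow from the same change applied to dynamic operands: a signless iN operand is now conservatively bounded by the signed maximum 2^(N-1)-1 instead of 2^N. As a minimal sketch (assuming only the formulas visible in the diff above), the MANP of an add_eint_int on a fresh ciphertext with a dynamic i3 operand works out as follows:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <llvm/ADT/APInt.h>

int main() {
  // Dynamic i3 operand: the conservative bound is now the signed maximum
  // 2^2 - 1 = 3 rather than the previous 2^3 = 8.
  llvm::APInt maxVal = llvm::APInt::getSignedMaxValue(3);
  uint64_t sqNorm = maxVal.getZExtValue() * maxVal.getZExtValue(); // 9

  // MANP of add_eint_int on a fresh ciphertext (squared norm 1):
  // ceil(sqrt(1 + 9)) = 4, matching the updated FileCheck values below.
  uint64_t manp =
      static_cast<uint64_t>(std::ceil(std::sqrt(1.0 + double(sqNorm))));
  assert(manp == 4);
  return 0;
}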


@@ -34,7 +34,8 @@ func @single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
{
// CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// sqrt(1 + (2^2-1)^2) = 3.16
// CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
%0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
return %0 : !FHE.eint<2>
@@ -66,7 +67,8 @@ func @single_cst_sub_int_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
{
// CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2>
// sqrt(1 + (2^2-1)^2) = 3.16
// CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2>
%0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2>
return %0 : !FHE.eint<2>
@@ -98,7 +100,8 @@ func @single_cst_mul_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
{
// CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// sqrt(1 * (2^2-1)^2) = 3
// CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
%0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
return %0 : !FHE.eint<2>


@@ -1,4 +1,4 @@
// RUN: concretecompiler --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
// RUN: concretecompiler --passes canonicalize --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
@@ -12,9 +12,22 @@ func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
// -----
func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
%cst1 = arith.constant 1 : i3
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
}
// -----
func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// sqrt(1 + (2^2-1)^2) = 3.16
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
@@ -44,6 +57,19 @@ func @single_cst_sub_int_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
// -----
func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
%cst1 = arith.constant 1 : i3
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
}
// -----
func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
@@ -56,7 +82,8 @@ func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
// sqrt(1 + (2^2-1)^2) = 3.16
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
@@ -76,9 +103,23 @@ func @single_cst_mul_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
// -----
func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
%cst1 = arith.constant 1 : i3
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
// %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
}
// -----
func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// sqrt(1 * (2^2-1)^2) = 3
// CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
return %0 : tensor<8x!FHE.eint<2>>
@@ -86,21 +127,21 @@ func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> t
// -----
func @chain_add_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint<3>>
{
%cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
%cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3>
%cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3>
%cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3>
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
return %3 : tensor<8x!FHE.eint<2>>
%cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi4>
%cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi4>
%cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi4>
%cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi4>
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
%0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
%1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
%2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
%3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
return %3 : tensor<8x!FHE.eint<3>>
}
// -----
@@ -132,7 +173,7 @@ func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3
func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> {
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>>
%res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>>
@@ -151,7 +192,7 @@ func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<3x3x4
// -----
func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> {
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
// CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>>
%res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>>
@@ -173,12 +214,13 @@ func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
return %0 : !FHE.eint<2>
}
// -----
func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.eint<2>
{
// sqrt(1*(2^3-1)^2*4) = 14
// CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
// sqrt(1*(2^2-1)^2*4) = 6
// CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
%0 = "FHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
return %0 : !FHE.eint<2>
@@ -188,13 +230,13 @@ func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.ein
func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
{
// sqrt((2^3)^2*1) = sqrt(64) = 8
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
// sqrt((2^2-1)^2*1) = sqrt(9) = 3
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}}
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
// sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[[0-9]+}}}
%cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3>
// sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}}
%1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
return %1 : !FHE.eint<2>
@@ -204,12 +246,12 @@ func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !
func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
{
// sqrt((2^3)^2*1) = sqrt(64) = 8
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
// sqrt((2^2-1)^2*1) = sqrt(9) = 3
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}}
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
// sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[[0-9]+}}}
// sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 18 : ui{{[[0-9]+}}}
%1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
return %1 : !FHE.eint<2>
@@ -224,10 +266,10 @@ func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !
func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// ceil(sqrt(65)) = 9
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// ceil(sqrt(10)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -237,13 +279,13 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2
func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// p = 1
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 65 = 129
// ceil(sqrt(129)) = 12
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 10 + 9 = 19
// ceil(sqrt(19)) = 5
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -507,10 +549,10 @@ func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<7>> {
func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// ceil(sqrt(65)) = 9
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// ceil(sqrt(10)) = 4
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -520,13 +562,13 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint
func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
// p = 0
// acc = manp(0) = 1
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 1 = 65
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 9 + 1 = 10
// p = 1
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
// manp(add_eint(mul, acc)) = 64 + 65 = 129
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
// manp(add_eint(mul, acc)) = 10 + 9 = 19
// ceil(sqrt(129)) = 12
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
return %1 : tensor<3x2x!FHE.eint<2>>
}
@@ -969,7 +1011,7 @@ func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> te
func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}}
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 64 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
@@ -980,7 +1022,7 @@ func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1
func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
%bias = arith.constant dense<[5]> : tensor<1xi3>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}}
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
@@ -990,7 +1032,7 @@ func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x
// -----
func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}}
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
@@ -1000,7 +1042,7 @@ func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: te
// -----
func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}}
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>


@@ -40,17 +40,14 @@ func @tensor_extract_1(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
// -----
func @tensor_extract_2(%a: !FHE.eint<2>) -> !FHE.eint<2>
func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
{
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : i3
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
%0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>>
// CHECK: %[[ret:.*]] = tensor.extract %[[V1]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
%2 = tensor.extract %1[%c1] : tensor<4x!FHE.eint<2>>
%c3 = arith.constant dense<3> : tensor<4xi3>
// CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
// CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
%2 = tensor.extract %0[%c1] : tensor<4x!FHE.eint<2>>
return %2 : !FHE.eint<2>
}
@@ -67,16 +64,15 @@ func @tensor_extract_slice_1(%t: tensor<2x10x!FHE.eint<2>>) -> tensor<1x5x!FHE.e
// -----
func @tensor_extract_slice_2(%a: !FHE.eint<2>) -> tensor<2x!FHE.eint<2>>
func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.eint<2>>
{
%c3 = arith.constant 3 : i3
%c3 = arith.constant dense<3> : tensor<4xi3>
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
%0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>>
// CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
%2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
// CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
// CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
%2 = tensor.extract_slice %0[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
return %2 : tensor<2x!FHE.eint<2>>
}
@@ -93,36 +89,6 @@ func @tensor_insert_slice_1(%t0: tensor<2x10x!FHE.eint<2>>, %t1: tensor<2x2x!FHE
// -----
func @tensor_insert_slice_2(%a: !FHE.eint<5>) -> tensor<4x!FHE.eint<5>>
{
%c3 = arith.constant 3 : i6
%c6 = arith.constant 6 : i6
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5>
%v0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
// CHECK: %[[V1:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5>
%v1 = "FHE.add_eint_int"(%a, %c6) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
// CHECK: %[[T0:.*]] = tensor.from_elements %[[V0]], %[[V0]], %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<5>>
%t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!FHE.eint<5>>
// CHECK: %[[T1:.*]] = tensor.from_elements %[[V1]], %[[V1]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>>
%t1 = tensor.from_elements %v1, %v1 : tensor<2x!FHE.eint<5>>
// CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
%t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
// CHECK: %[[T3:.*]] = tensor.from_elements %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>>
%t3 = tensor.from_elements %v0, %v0 : tensor<2x!FHE.eint<5>>
// CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
%t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
return %t0 : tensor<4x!FHE.eint<5>>
}
// -----
func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE.eint<6>> {
// CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!FHE.eint<6>> into tensor<2x8x!FHE.eint<6>>
@@ -133,9 +99,9 @@ func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE
func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>>
{
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>>
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}}
%1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>>
return %1 : tensor<2x8x!FHE.eint<2>>
}
@@ -152,9 +118,9 @@ func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x!FHE.e
func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>>
{
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}}
%0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>>
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}}
%1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>>
return %1 : tensor<2x2x4x!FHE.eint<2>>
}