mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): Use the absolute value when computing the square of a constant in MANP analysis
The computation of the norm2 should take care of the sign of the constant to compute the square of this constant.
This commit is contained in:
@@ -183,7 +183,7 @@ static llvm::APInt APIntUMax(const llvm::APInt &lhs, const llvm::APInt &rhs) {
|
||||
// Calculates the square of `i`. The bit width `i` is extended in
|
||||
// order to guarantee that the product fits into the resulting
|
||||
// `APInt`.
|
||||
static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
|
||||
static llvm::APInt APIntWidthExtendUnsignedSq(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(i.getBitWidth() < std::numeric_limits<unsigned>::max() / 2 &&
|
||||
@@ -194,12 +194,22 @@ static llvm::APInt APIntWidthExtendUSq(const llvm::APInt &i) {
|
||||
return ie * ie;
|
||||
}
|
||||
|
||||
// Calculates the square of the absolute value of `i`.
|
||||
static llvm::APInt APIntWidthExtendSqForConstant(const llvm::APInt &i) {
|
||||
// Make sure the required number of bits can be represented by the
|
||||
// `unsigned` argument of `zext`.
|
||||
assert(i.getBitWidth() < 32 &&
|
||||
"Square of the constant cannot be represented on 64 bits");
|
||||
return llvm::APInt(2 * i.getBitWidth(),
|
||||
i.abs().getZExtValue() * i.abs().getZExtValue());
|
||||
}
|
||||
|
||||
// Calculates the square root of `i` and rounds it to the next highest
|
||||
// integer value (i.e., the square of the result is guaranteed to be
|
||||
// greater or equal to `i`).
|
||||
static llvm::APInt APIntCeilSqrt(const llvm::APInt &i) {
|
||||
llvm::APInt res = i.sqrt();
|
||||
llvm::APInt resSq = APIntWidthExtendUSq(res);
|
||||
llvm::APInt resSq = APIntWidthExtendUnsignedSq(res);
|
||||
|
||||
if (APIntWidthExtendULT(resSq, i))
|
||||
return APIntWidthExtendUAdd(res, llvm::APInt{1, 1, false});
|
||||
@@ -234,7 +244,7 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
|
||||
llvm::APInt accu{1, 0, false};
|
||||
|
||||
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
|
||||
llvm::APInt valSqNorm = APIntWidthExtendUSq(val);
|
||||
llvm::APInt valSqNorm = APIntWidthExtendSqForConstant(val);
|
||||
llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm);
|
||||
accu = APIntWidthExtendUAdd(accu, mulSqNorm);
|
||||
}
|
||||
@@ -258,8 +268,9 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
|
||||
|
||||
unsigned elWidth = tTy.getElementTypeBitWidth();
|
||||
|
||||
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
|
||||
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
|
||||
llvm::APInt maxVal = APInt::getSignedMaxValue(elWidth);
|
||||
llvm::APInt maxValSq = APIntWidthExtendUnsignedSq(maxVal);
|
||||
|
||||
llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm);
|
||||
|
||||
// Calculate number of bits for APInt to store number of elements
|
||||
@@ -272,6 +283,30 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
|
||||
return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP);
|
||||
}
|
||||
|
||||
// Returns the squared 2-norm of the maximum value of the dense values.
|
||||
static llvm::APInt maxIntNorm2Sq(mlir::DenseIntElementsAttr denseVals) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
|
||||
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
|
||||
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
if (maxCst.ult(iCst)) {
|
||||
maxCst = iCst;
|
||||
}
|
||||
}
|
||||
return APIntWidthExtendSqForConstant(maxCst);
|
||||
}
|
||||
|
||||
// Returns the squared 2-norm for a dynamic integer by conservatively
|
||||
// assuming that the integer's value is the maximum for the integer
|
||||
// width.
|
||||
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
|
||||
assert(t.isSignlessInteger() && "Type must be a signless integer type");
|
||||
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
|
||||
|
||||
llvm::APInt maxVal = APInt::getSignedMaxValue(t.getIntOrFloatBitWidth());
|
||||
return APIntWidthExtendUnsignedSq(maxVal);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
// `FHELinalg.dot_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -306,18 +341,6 @@ static llvm::APInt getSqMANP(
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the squared 2-norm for a dynamic integer by conservatively
|
||||
// assuming that the integer's value is the maximum for the integer
|
||||
// width.
|
||||
static llvm::APInt conservativeIntNorm2Sq(mlir::Type t) {
|
||||
assert(t.isSignlessInteger() && "Type must be a signless integer type");
|
||||
assert(std::numeric_limits<unsigned>::max() - t.getIntOrFloatBitWidth() > 1);
|
||||
|
||||
llvm::APInt maxVal{t.getIntOrFloatBitWidth() + 1, 1, false};
|
||||
maxVal <<= t.getIntOrFloatBitWidth();
|
||||
return APIntWidthExtendUSq(maxVal);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of an
|
||||
// `FHE.add_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -343,7 +366,7 @@ static llvm::APInt getSqMANP(
|
||||
if (cstOp) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
@@ -395,7 +418,7 @@ static llvm::APInt getSqMANP(
|
||||
if (cstOp) {
|
||||
// For constant plaintext operands simply use the constant value
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
|
||||
} else {
|
||||
// For dynamic plaintext operands conservatively assume that the integer has
|
||||
// its maximum possible value
|
||||
@@ -445,7 +468,7 @@ static llvm::APInt getSqMANP(
|
||||
if (cstOp) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendUSq(attr.getValue());
|
||||
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
@@ -486,14 +509,7 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
if (denseVals) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
|
||||
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
|
||||
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
if (maxCst.ult(iCst)) {
|
||||
maxCst = iCst;
|
||||
}
|
||||
}
|
||||
sqNorm = APIntWidthExtendUSq(maxCst);
|
||||
sqNorm = maxIntNorm2Sq(denseVals);
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
@@ -548,15 +564,7 @@ static llvm::APInt getSqMANP(
|
||||
: nullptr;
|
||||
|
||||
if (denseVals) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
|
||||
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
|
||||
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
if (maxCst.ult(iCst)) {
|
||||
maxCst = iCst;
|
||||
}
|
||||
}
|
||||
sqNorm = APIntWidthExtendUSq(maxCst);
|
||||
sqNorm = maxIntNorm2Sq(denseVals);
|
||||
} else {
|
||||
// For dynamic plaintext operands conservatively assume that the integer has
|
||||
// its maximum possible value
|
||||
@@ -612,14 +620,7 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
if (denseVals) {
|
||||
// For a constant operand use actual constant to calculate 2-norm
|
||||
llvm::APInt maxCst = denseVals.getFlatValue<llvm::APInt>(0);
|
||||
for (int64_t i = 0; i < denseVals.getNumElements(); i++) {
|
||||
llvm::APInt iCst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
if (maxCst.ult(iCst)) {
|
||||
maxCst = iCst;
|
||||
}
|
||||
}
|
||||
sqNorm = APIntWidthExtendUSq(maxCst);
|
||||
sqNorm = maxIntNorm2Sq(denseVals);
|
||||
} else {
|
||||
// For a dynamic operand conservatively assume that the value is
|
||||
// the maximum for the integer width
|
||||
@@ -639,7 +640,7 @@ static llvm::APInt computeVectorNorm(
|
||||
elementSelector[axis] = i;
|
||||
|
||||
llvm::APInt weight = denseValues.getValue<llvm::APInt>(elementSelector);
|
||||
llvm::APInt weightNorm = APIntWidthExtendUSq(weight);
|
||||
llvm::APInt weightNorm = APIntWidthExtendSqForConstant(weight);
|
||||
|
||||
llvm::APInt multiplicationNorm =
|
||||
APIntWidthExtendUMul(encryptedOperandNorm, weightNorm);
|
||||
@@ -749,7 +750,7 @@ static llvm::APInt getSqMANP(
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
llvm::APInt cst =
|
||||
denseVals.getValue<llvm::APInt>({(uint64_t)n, (uint64_t)p});
|
||||
llvm::APInt rhsNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
|
||||
}
|
||||
@@ -765,7 +766,7 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
llvm::APInt rhsNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt rhsNorm = APIntWidthExtendSqForConstant(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
|
||||
}
|
||||
@@ -849,7 +850,7 @@ static llvm::APInt getSqMANP(
|
||||
for (int64_t n = 0; n < N; n++) {
|
||||
llvm::APInt cst =
|
||||
denseVals.getValue<llvm::APInt>({(uint64_t)m, (uint64_t)n});
|
||||
llvm::APInt lhsNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
|
||||
}
|
||||
@@ -865,7 +866,7 @@ static llvm::APInt getSqMANP(
|
||||
|
||||
for (int64_t i = 0; i < N; i++) {
|
||||
llvm::APInt cst = denseVals.getFlatValue<llvm::APInt>(i);
|
||||
llvm::APInt lhsNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt lhsNorm = APIntWidthExtendSqForConstant(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(lhsNorm, rhsNorm);
|
||||
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
|
||||
}
|
||||
@@ -1106,14 +1107,14 @@ static llvm::APInt getSqMANP(
|
||||
// If there is a bias, start accumulating from its norm
|
||||
if (hasBias && biasDenseVals) {
|
||||
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
|
||||
tmpNorm = APIntWidthExtendUSq(cst);
|
||||
tmpNorm = APIntWidthExtendSqForConstant(cst);
|
||||
}
|
||||
for (uint64_t c = 0; c < C; c++) {
|
||||
for (uint64_t h = 0; h < H; h++) {
|
||||
for (uint64_t w = 0; w < W; w++) {
|
||||
llvm::APInt cst =
|
||||
weightDenseVals.getValue<llvm::APInt>({f, c, h, w});
|
||||
llvm::APInt weightNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt weightNorm = APIntWidthExtendSqForConstant(cst);
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
|
||||
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
|
||||
}
|
||||
@@ -1138,7 +1139,7 @@ static llvm::APInt getSqMANP(
|
||||
llvm::APInt maxNorm = tmpNorm;
|
||||
for (uint64_t f = 0; f < F; f++) {
|
||||
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
|
||||
llvm::APInt currentNorm = APIntWidthExtendUSq(cst);
|
||||
llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst);
|
||||
currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
|
||||
maxNorm = APIntUMax(currentNorm, maxNorm);
|
||||
}
|
||||
|
||||
@@ -34,7 +34,8 @@ func @single_cst_add_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
func @single_dyn_add_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// sqrt(1 + (2^2-1)^2) = 3.16
|
||||
// CHECK: %[[ret:.*]] = "FHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.add_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
@@ -66,7 +67,8 @@ func @single_cst_sub_int_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// sqrt(1 + (2^2-1)^2) = 3.16
|
||||
// CHECK: %[[ret:.*]] = "FHE.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (i3, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%0 = "FHE.sub_int_eint"(%i, %e) : (i3, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
@@ -98,7 +100,8 @@ func @single_cst_mul_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// sqrt(1 + (2^2-1)^2) = 3
|
||||
// CHECK: %[[ret:.*]] = "FHE.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.mul_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// RUN: concretecompiler --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
// RUN: concretecompiler --passes canonicalize --passes MANP --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
@@ -12,9 +12,22 @@ func @single_cst_add_eint_int(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_add_eint_int_from_cst_elements(%t: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
%cst1 = arith.constant 1 : i3
|
||||
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.add_eint_int"(%t, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
func @single_dyn_add_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// sqrt(1 + (2^2-1)^2) = 3..16
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.add_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
@@ -44,6 +57,19 @@ func @single_cst_sub_int_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
%cst1 = arith.constant 1 : i3
|
||||
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.sub_int_eint"(%cst, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
@@ -56,7 +82,8 @@ func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
func @single_dyn_sub_int_eint(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 9 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
// sqrt(1 + (2^2-1)^2) = 3.16
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_int_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.sub_int_eint"(%i, %e) : (tensor<8xi3>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
@@ -76,9 +103,23 @@ func @single_cst_mul_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_mul_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
%cst1 = arith.constant 1 : i3
|
||||
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
|
||||
|
||||
// %0 = "FHELinalg.mul_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.mul_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// sqrt(1 * (2^2-1)^2) = 3.16
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.mul_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.mul_eint_int"(%e, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
@@ -86,21 +127,21 @@ func @single_dyn_mul_eint_int(%e: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> t
|
||||
|
||||
// -----
|
||||
|
||||
func @chain_add_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
func @chain_add_eint_int(%e: tensor<8x!FHE.eint<3>>) -> tensor<8x!FHE.eint<3>>
|
||||
{
|
||||
%cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
|
||||
%cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi3>
|
||||
%cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi3>
|
||||
%cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi3>
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
return %3 : tensor<8x!FHE.eint<2>>
|
||||
%cst0 = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi4>
|
||||
%cst1 = arith.constant dense<[0, 7, 2, 5, 6, 2, 1, 7]> : tensor<8xi4>
|
||||
%cst2 = arith.constant dense<[0, 1, 2, 0, 1, 2, 0, 1]> : tensor<8xi4>
|
||||
%cst3 = arith.constant dense<[0, 1, 1, 0, 0, 1, 0, 1]> : tensor<8xi4>
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
%0 = "FHELinalg.add_eint_int"(%e, %cst0) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
%1 = "FHELinalg.add_eint_int"(%0, %cst1) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
%2 = "FHELinalg.add_eint_int"(%1, %cst2) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "FHELinalg.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
%3 = "FHELinalg.add_eint_int"(%2, %cst3) : (tensor<8x!FHE.eint<3>>, tensor<8xi4>) -> tensor<8x!FHE.eint<3>>
|
||||
return %3 : tensor<8x!FHE.eint<3>>
|
||||
}
|
||||
|
||||
// -----
|
||||
@@ -132,7 +173,7 @@ func @apply_lookup_table(%t: tensor<3x3x!FHE.eint<2>>) -> tensor<3x3x!FHE.eint<3
|
||||
|
||||
func @apply_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>) -> tensor<8x!FHE.eint<3>> {
|
||||
%lut = arith.constant dense<[1,3,5,7]> : tensor<4xi64>
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>>
|
||||
%res = "FHELinalg.apply_lookup_table"(%0, %lut) : (tensor<8x!FHE.eint<2>>, tensor<4xi64>) -> tensor<8x!FHE.eint<3>>
|
||||
@@ -151,7 +192,7 @@ func @apply_multi_lookup_table(%t: tensor<3x3x!FHE.eint<2>>, %luts: tensor<3x3x4
|
||||
// -----
|
||||
|
||||
func @apply_multi_lookup_table_after_op(%t: tensor<8x!FHE.eint<2>>, %i: tensor<8xi3>, %luts: tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>> {
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: %[[RES:.*]] = "FHELinalg.apply_multi_lookup_table"(%[[V0]], %[[LUT:.*]]) {MANP = 1 : ui1} : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>>
|
||||
%res = "FHELinalg.apply_multi_lookup_table"(%0, %luts) : (tensor<8x!FHE.eint<2>>, tensor<8x4xi64>) -> tensor<8x!FHE.eint<3>>
|
||||
@@ -173,12 +214,13 @@ func @single_cst_dot(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.eint<2>
|
||||
{
|
||||
// sqrt(1*(2^3-1)^2*4) = 14
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
// sqrt(1*(2^2-1)^2*4) = 16
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
%0 = "FHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
@@ -188,13 +230,13 @@ func @single_dyn_dot(%t: tensor<4x!FHE.eint<2>>, %dyn: tensor<4xi3>) -> !FHE.ein
|
||||
|
||||
func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
|
||||
{
|
||||
// sqrt((2^3)^2*1) = sqrt(64) = 8
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
|
||||
// sqrt((2^2-1)^2*1) = sqrt(9) = 3
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
|
||||
// sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[[0-9]+}}}
|
||||
%cst = arith.constant dense<[1, 2, 3, -1]> : tensor<4xi3>
|
||||
// sqrt(1^2*9 + 2^2*9 + 3^2*9 + 1^2*9) = sqrt(135) = 12
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 12 : ui{{[[0-9]+}}}
|
||||
%1 = "FHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
|
||||
return %1 : !FHE.eint<2>
|
||||
@@ -204,12 +246,12 @@ func @single_cst_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !
|
||||
|
||||
func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !FHE.eint<2>
|
||||
{
|
||||
// sqrt((2^3)^2*1) = sqrt(64) = 8
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
|
||||
// sqrt((2^2-1)^2*1) = sqrt(9) = 3
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 3 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
|
||||
// sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[[0-9]+}}}
|
||||
// sqrt(4*(2^2-1)^2*9) = sqrt(324) = 18
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 18 : ui{{[[0-9]+}}}
|
||||
%1 = "FHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> !FHE.eint<2>
|
||||
|
||||
return %1 : !FHE.eint<2>
|
||||
@@ -224,10 +266,10 @@ func @single_dyn_dot_after_op(%t: tensor<4x!FHE.eint<2>>, %i: tensor<4xi3>) -> !
|
||||
func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// ceil(sqrt(65)) = 9
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 9 + 1 = 10
|
||||
// ceil(sqrt(65)) = 4
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x1x!FHE.eint<2>>, tensor<1x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -237,13 +279,13 @@ func @matmul_eint_int_dyn_p_1(%arg0: tensor<3x1x!FHE.eint<2>>, %arg1: tensor<1x2
|
||||
func @matmul_eint_int_dyn_p_2(%arg0: tensor<3x2x!FHE.eint<2>>, %arg1: tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 9 + 1 = 10
|
||||
// p = 1
|
||||
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 65 = 129
|
||||
// ceil(sqrt(129)) = 12
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
|
||||
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 10 + 9 = 19
|
||||
// ceil(sqrt(19)) = 5
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_eint_int"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_eint_int"(%arg0, %arg1): (tensor<3x2x!FHE.eint<2>>, tensor<2x2xi3>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -507,10 +549,10 @@ func @matmul_eint_int_cst_different_operand_manp() -> tensor<4x3x!FHE.eint<7>> {
|
||||
func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// ceil(sqrt(65)) = 9
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 10
|
||||
// ceil(sqrt(65)) = 4
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x1xi3>, tensor<1x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -520,13 +562,13 @@ func @matmul_int_eint_dyn_p_1(%arg0: tensor<3x1xi3>, %arg1: tensor<1x2x!FHE.eint
|
||||
func @matmul_int_eint_dyn_p_2(%arg0: tensor<3x2xi3>, %arg1: tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>> {
|
||||
// p = 0
|
||||
// acc = manp(0) = 1
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 65
|
||||
// mul = manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 64 + 1 = 10
|
||||
// p = 1
|
||||
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^3)^2 = 64
|
||||
// manp(add_eint(mul, acc)) = 64 + 65 = 129
|
||||
// manp(mul_eint_int(eint<2>, i3) = 1 * (2^2-1)^2 = 9
|
||||
// manp(add_eint(mul, acc)) = 10 + 9 = 19
|
||||
// ceil(sqrt(129)) = 12
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 12 : ui{{[0-9]+}}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.matmul_int_eint"(%[[A0:.*]], %[[A1:.*]]) {MANP = 5 : ui{{[0-9]+}}}
|
||||
%1 = "FHELinalg.matmul_int_eint"(%arg0, %arg1): (tensor<3x2xi3>, tensor<2x2x!FHE.eint<2>>) -> tensor<3x2x!FHE.eint<2>>
|
||||
return %1 : tensor<3x2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -969,7 +1011,7 @@ func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> te
|
||||
|
||||
func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 64 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
@@ -980,7 +1022,7 @@ func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1
|
||||
|
||||
func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
%bias = arith.constant dense<[5]> : tensor<1xi3>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
@@ -990,7 +1032,7 @@ func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x
|
||||
// -----
|
||||
|
||||
func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 7 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
@@ -1000,7 +1042,7 @@ func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: te
|
||||
// -----
|
||||
|
||||
func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}}
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 11 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
|
||||
|
||||
@@ -40,17 +40,14 @@ func @tensor_extract_1(%t: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_2(%a: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @tensor_extract_2(%a: tensor<4x!FHE.eint<2>>) -> !FHE.eint<2>
|
||||
{
|
||||
%c1 = arith.constant 1 : index
|
||||
%c3 = arith.constant 3 : i3
|
||||
|
||||
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
|
||||
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>>
|
||||
// CHECK: %[[ret:.*]] = tensor.extract %[[V1]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
|
||||
%2 = tensor.extract %1[%c1] : tensor<4x!FHE.eint<2>>
|
||||
%c3 = arith.constant dense<3> : tensor<4xi3>
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
// CHECK: %[[ret:.*]] = tensor.extract %[[V0]][%[[c3:.*]]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
|
||||
%2 = tensor.extract %0[%c1] : tensor<4x!FHE.eint<2>>
|
||||
|
||||
return %2 : !FHE.eint<2>
|
||||
}
|
||||
@@ -67,16 +64,15 @@ func @tensor_extract_slice_1(%t: tensor<2x10x!FHE.eint<2>>) -> tensor<1x5x!FHE.e
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_extract_slice_2(%a: !FHE.eint<2>) -> tensor<2x!FHE.eint<2>>
|
||||
func @tensor_extract_slice_2(%a: tensor<4x!FHE.eint<2>>) -> tensor<2x!FHE.eint<2>>
|
||||
{
|
||||
%c3 = arith.constant 3 : i3
|
||||
%c3 = arith.constant dense <3> : tensor<4xi3>
|
||||
|
||||
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK: %[[V1:.*]] = tensor.from_elements %[[V0]], %[[a:.*]], %[[a:.*]], %[[a:.*]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<2>>
|
||||
%1 = tensor.from_elements %0, %a, %a, %a : tensor<4x!FHE.eint<2>>
|
||||
// CHECK: tensor.extract_slice %[[V1]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
|
||||
%2 = tensor.extract_slice %1[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
|
||||
// CHECK: %[[V0:.*]] = "FHELinalg.add_eint_int"(%[[a:.*]], %[[c1:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.add_eint_int"(%a, %c3) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
|
||||
// CHECK: tensor.extract_slice %[[V0]][2] [2] [1] {MANP = 4 : ui{{[0-9]+}}} : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
|
||||
%2 = tensor.extract_slice %0[2] [2] [1] : tensor<4x!FHE.eint<2>> to tensor<2x!FHE.eint<2>>
|
||||
|
||||
return %2 : tensor<2x!FHE.eint<2>>
|
||||
}
|
||||
@@ -93,36 +89,6 @@ func @tensor_insert_slice_1(%t0: tensor<2x10x!FHE.eint<2>>, %t1: tensor<2x2x!FHE
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_insert_slice_2(%a: !FHE.eint<5>) -> tensor<4x!FHE.eint<5>>
|
||||
{
|
||||
%c3 = arith.constant 3 : i6
|
||||
%c6 = arith.constant 6 : i6
|
||||
|
||||
// CHECK: %[[V0:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c3:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5>
|
||||
%v0 = "FHE.add_eint_int"(%a, %c3) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
|
||||
// CHECK: %[[V1:.*]] = "FHE.add_eint_int"(%[[a:.*]], %[[c6:.*]]) {MANP = 7 : ui{{[0-9]+}}} : (!FHE.eint<5>, i6) -> !FHE.eint<5>
|
||||
%v1 = "FHE.add_eint_int"(%a, %c6) : (!FHE.eint<5>, i6) -> !FHE.eint<5>
|
||||
|
||||
// CHECK: %[[T0:.*]] = tensor.from_elements %[[V0]], %[[V0]], %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<4x!FHE.eint<5>>
|
||||
%t0 = tensor.from_elements %v0, %v0, %v0, %v0 : tensor<4x!FHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T1:.*]] = tensor.from_elements %[[V1]], %[[V1]] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>>
|
||||
%t1 = tensor.from_elements %v1, %v1 : tensor<2x!FHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T2:.*]] = tensor.insert_slice %[[T1]] into %[[T0]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
|
||||
%t2 = tensor.insert_slice %t1 into %t0[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T3:.*]] = tensor.from_elements %[[V0]], %[[V0]] {MANP = 4 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>>
|
||||
%t3 = tensor.from_elements %v0, %v0 : tensor<2x!FHE.eint<5>>
|
||||
|
||||
// CHECK: %[[T4:.*]] = tensor.insert_slice %[[T3]] into %[[T2]][0] [2] [1] {MANP = 7 : ui{{[[0-9]+}}} : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
|
||||
%t4 = tensor.insert_slice %t3 into %t2[0] [2] [1] : tensor<2x!FHE.eint<5>> into tensor<4x!FHE.eint<5>>
|
||||
|
||||
return %t0 : tensor<4x!FHE.eint<5>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE.eint<6>> {
|
||||
// CHECK: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 1 : ui{{[0-9]+}}}
|
||||
%0 = linalg.tensor_collapse_shape %a [[0],[1,2]] : tensor<2x2x4x!FHE.eint<6>> into tensor<2x8x!FHE.eint<6>>
|
||||
@@ -133,9 +99,9 @@ func @tensor_collapse_shape_1(%a: tensor<2x2x4x!FHE.eint<6>>) -> tensor<2x8x!FHE
|
||||
|
||||
func @tensor_collapse_shape_2(%a: tensor<2x2x4x!FHE.eint<2>>, %b: tensor<2x2x4xi3>) -> tensor<2x8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x2x4x!FHE.eint<2>>, tensor<2x2x4xi3>) -> tensor<2x2x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK-NEXT: linalg.tensor_collapse_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}}
|
||||
%1 = linalg.tensor_collapse_shape %0 [[0],[1,2]] : tensor<2x2x4x!FHE.eint<2>> into tensor<2x8x!FHE.eint<2>>
|
||||
return %1 : tensor<2x8x!FHE.eint<2>>
|
||||
}
|
||||
@@ -152,9 +118,9 @@ func @tensor_expand_shape_1(%a: tensor<2x8x!FHE.eint<6>>) -> tensor<2x2x4x!FHE.e
|
||||
|
||||
func @tensor_expand_shape_2(%a: tensor<2x8x!FHE.eint<2>>, %b: tensor<2x8xi3>) -> tensor<2x2x4x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK: "FHELinalg.add_eint_int"(%[[A:.*]], %[[B:.*]]) {MANP = 4 : ui{{[0-9]+}}}
|
||||
%0 = "FHELinalg.add_eint_int"(%a, %b) : (tensor<2x8x!FHE.eint<2>>, tensor<2x8xi3>) -> tensor<2x8x!FHE.eint<2>>
|
||||
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 9 : ui{{[0-9]+}}}
|
||||
// CHECK-NEXT: linalg.tensor_expand_shape %[[A:.*]] [[X:.*]] {MANP = 4 : ui{{[0-9]+}}}
|
||||
%1 = linalg.tensor_expand_shape %0 [[0],[1,2]] : tensor<2x8x!FHE.eint<2>> into tensor<2x2x4x!FHE.eint<2>>
|
||||
return %1 : tensor<2x2x4x!FHE.eint<2>>
|
||||
}
|
||||
Reference in New Issue
Block a user