feat: implement all kinds of subtractions

This commit is contained in:
Umut
2022-06-23 12:28:04 +02:00
parent 8f8a57d220
commit b3a2671dc7
15 changed files with 1200 additions and 4 deletions

View File

@@ -148,6 +148,71 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> {
let hasVerifier = 1;
}
def SubEintIntOp : FHE_Op<"sub_eint_int"> {
  // NOTE: fixed typos in the user-visible documentation ("Substract",
  // "same width than"); op semantics are unchanged.
  let summary = "Subtract a clear integer from an encrypted integer";

  let description = [{
    Subtract a clear integer from an encrypted integer.
    The clear integer must have at most one more bit than the encrypted integer
    and the result must have the same width as the encrypted integer.

    Example:
    ```mlir
    // ok
    "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2>

    // error
    "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
    "FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
    ```
  }];

  let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
  let results = (outs EncryptedIntegerType);

  // Convenience builder: the result type is the type of the encrypted operand.
  let builders = [
    OpBuilder<(ins "Value":$a, "Value":$b), [{
      build($_builder, $_state, a.getType(), a, b);
    }]>
  ];

  let hasVerifier = 1;
  let hasFolder = 1;
}
// Encrypted - encrypted subtraction on scalar values.
def SubEintOp : FHE_Op<"sub_eint"> {
  let summary = "Subtracts two encrypted integers";

  let description = [{
    Subtracts two encrypted integers
    The encrypted integers and the result must have the same width.

    Example:
    ```mlir
    // ok
    "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)

    // error
    "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
    "FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
    ```
  }];

  let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b);
  let results = (outs EncryptedIntegerType);

  // Convenience builder: the result type is taken from the first operand.
  let builders = [
    OpBuilder<(ins "Value":$a, "Value":$b), [{
      build($_builder, $_state, a.getType(), a, b);
    }]>
  ];

  let hasVerifier = 1;
}
def NegEintOp : FHE_Op<"neg_eint"> {
let summary = "Negates an encrypted integer";
@@ -224,7 +289,7 @@ def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<2>)
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<3>, tensor<4xi64>) -> (!FHE.eint<2>)
// error
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<8xi64>) -> (!FHE.eint<2>)
```

View File

@@ -181,6 +181,114 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor
];
}
def SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
  // NOTE: fixed typos ("substraction", "witdh") and corrected the example
  // comments, which described the operands/stretched dimensions of the
  // sibling `sub_int_eint` op instead of this one. Op semantics unchanged.
  let summary = "Returns a tensor that contains the subtraction of a tensor of clear integers from a tensor of encrypted integers.";

  let description = [{
    Performs a subtraction following the broadcasting rules between a tensor of
    encrypted integers and a tensor of clear integers.
    The width of the clear integers must be less than or equal to the width of
    the encrypted integers.

    Examples:
    ```mlir
    // Returns the term to term subtraction of `%a0` with `%a1`
    "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>

    // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
    "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x4x4x!FHE.eint<4>>, tensor<4x1x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>

    // Returns the subtraction of a 3x1 matrix (a column) of encrypted integers
    // and a 3x3 matrix of clear integers.
    //
    // [5]   [1,2,3]   [4,3,2]
    // [5] - [1,2,3] = [4,3,2]
    // [5]   [1,2,3]   [4,3,2]
    //
    // The dimension #2 of operand #1 is stretched as it is equal to 1.
    "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x1x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>

    // Returns the subtraction of a 1x3 matrix (a line) of encrypted integers
    // and a 3x3 matrix of clear integers.
    //
    //           [1,1,1]   [4,4,4]
    // [5,5,5] - [2,2,2] = [3,3,3]
    //           [3,3,3]   [2,2,2]
    //
    // The dimension #1 of operand #1 is stretched as it is equal to 1.
    "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>

    // Same behavior as the previous example, but dimension #1 of operand #1 is missing instead of being 1.
    "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
    Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
  );
  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);

  // Convenience builder: the result type is the type of the encrypted operand.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
      build($_builder, $_state, lhs.getType(), lhs, rhs);
    }]>
  ];

  let hasFolder = 1;
}
def SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
  // NOTE: fixed typos ("an subtraction follwing", "substraction"), the wrong
  // first row of the column-broadcast example ([1,2,3] - 1 = [0,1,2], not
  // [0,2,3]), the "integers"/"encrypted integers" mix-up (both operands are
  // encrypted here), and the swapped dimension numbers in the broadcast
  // notes. Op semantics unchanged.
  let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers.";

  let description = [{
    Performs a subtraction following the broadcasting rules between two tensors
    of encrypted integers.
    The width of the encrypted integers must be equal.

    Examples:
    ```mlir
    // Returns the term to term subtraction of `%a0` with `%a1`
    "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>

    // Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
    "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>

    // Returns the subtraction of a 3x3 matrix of encrypted integers and a 3x1
    // matrix (a column) of encrypted integers.
    //
    // [1,2,3]   [1]   [0,1,2]
    // [4,5,6] - [2] = [2,3,4]
    // [7,8,9]   [3]   [4,5,6]
    //
    // The dimension #2 of operand #2 is stretched as it is equal to 1.
    "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>

    // Returns the subtraction of a 3x3 matrix of encrypted integers and a 1x3
    // matrix (a line) of encrypted integers.
    //
    // [1,2,3]             [0,0,0]
    // [4,5,6] - [1,2,3] = [3,3,3]
    // [7,8,9]             [6,6,6]
    //
    // The dimension #1 of operand #2 is stretched as it is equal to 1.
    "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>

    // Same behavior as the previous example, but dimension #1 of operand #2 is missing instead of being 1.
    "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
    ```
  }];

  let arguments = (ins
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
    Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
  );
  let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);

  // Convenience builder: the result type is the type of the lhs operand.
  let builders = [
    OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
      build($_builder, $_state, lhs.getType(), lhs, rhs);
    }]>
  ];
}
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers.";

View File

@@ -1613,6 +1613,14 @@ void FHETensorOpsToLinalg::runOnOperation() {
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubIntEintOp,
mlir::concretelang::FHE::SubIntEintOp>>(
&getContext());
patterns.insert<
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubEintIntOp,
mlir::concretelang::FHE::SubEintIntOp>>(
&getContext());
patterns.insert<
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubEintOp,
mlir::concretelang::FHE::SubEintOp>>(
&getContext());
patterns.insert<
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::MulEintIntOp,
mlir::concretelang::FHE::MulEintIntOp>>(

View File

@@ -115,6 +115,85 @@ struct ApplyLookupTableEintOpPattern
};
};
// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
// operators to a negation and an addition.
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
  SubEintIntOpPattern(mlir::MLIRContext *context,
                      mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::SubEintIntOp>(context, benefit) {}

  // Lowers `a - b` (encrypted - clear) to `a + (b * -1)`: the clear operand
  // is negated with an `arith.muli` against the constant -1, then added to
  // the encrypted operand with `TFHE.AddGLWEIntOp`.
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintIntOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();

    mlir::Value lhs = op.getOperand(0); // encrypted operand
    mlir::Value rhs = op.getOperand(1); // clear (plaintext) operand

    // Build `rhs * -1` in the clear operand's own integer type.
    mlir::Type rhsType = rhs.getType();
    mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(rhsType, -1);
    mlir::Value minusOne =
        rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
            .getResult();
    mlir::Value negative =
        rewriter.create<mlir::arith::MulIOp>(location, rhs, minusOne)
            .getResult();

    // Add the negated clear value to the encrypted operand, giving the
    // addition the TFHE result type produced by the type converter.
    FHEToTFHETypeConverter converter;
    auto resultTy = converter.convertType(op.getType());
    auto addition =
        rewriter.create<TFHE::AddGLWEIntOp>(location, resultTy, lhs, negative);
    // Rewrite any remaining FHE operand/result types on the new op to their
    // TFHE counterparts.
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    rewriter.replaceOp(op, {addition.getResult()});
    return mlir::success();
  };
};
// This rewrite pattern transforms any instance of `FHE.sub_eint`
// operators to a negation and an addition.
struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
  SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
      : ::mlir::OpRewritePattern<FHE::SubEintOp>(context, benefit) {}

  // Lowers `a - b` (both encrypted) to `a + (-b)`: the right-hand side is
  // negated with `TFHE.NegGLWEOp` and added back with `TFHE.AddGLWEOp`.
  ::mlir::LogicalResult
  matchAndRewrite(FHE::SubEintOp op,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location location = op.getLoc();
    mlir::Value lhs = op.getOperand(0);
    mlir::Value rhs = op.getOperand(1);

    // Negate the right-hand side in its converted (TFHE) type.
    FHEToTFHETypeConverter converter;
    auto rhsTy = converter.convertType(rhs.getType());
    auto negative = rewriter.create<TFHE::NegGLWEOp>(location, rhsTy, rhs);
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, negative, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });

    // Add the negated value to the left-hand side and replace the original
    // subtraction with the result.
    auto resultTy = converter.convertType(op.getType());
    auto addition = rewriter.create<TFHE::AddGLWEOp>(location, resultTy, lhs,
                                                     negative.getResult());
    mlir::concretelang::convertOperandAndResultTypes(
        rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
          return converter.convertType(t);
        });
    rewriter.replaceOp(op, {addition.getResult()});
    return mlir::success();
  };
};
void FHEToTFHEPass::runOnOperation() {
auto op = this->getOperation();
@@ -123,6 +202,7 @@ void FHEToTFHEPass::runOnOperation() {
// Mark ops from the target dialect as legal operations
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
@@ -155,6 +235,9 @@ void FHEToTFHEPass::runOnOperation() {
patterns.getContext(), converter);
patterns.add<ApplyLookupTableEintOpPattern>(&getContext());
patterns.add<SubEintOpPattern>(&getContext());
patterns.add<SubEintIntOpPattern>(&getContext());
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
FHEToTFHETypeConverter>>(
&getContext(), converter);

View File

@@ -429,6 +429,57 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHE.sub_eint_int` operation: the squared value of the clear operand (or a
// conservative bound when it is not a constant) is added to the squared norm
// of the encrypted operand.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHE::SubEintIntOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  mlir::Type iTy = op->getOpOperand(1).get().getType();

  assert(iTy.isSignlessInteger() &&
         "Only subtractions with signless integers are currently allowed");

  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().hasValue() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
  llvm::APInt sqNorm;

  mlir::arith::ConstantOp cstOp =
      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
          op->getOpOperand(1).get().getDefiningOp());

  if (cstOp) {
    // For constant plaintext operands simply use the constant value
    mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
    sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
  } else {
    // For dynamic plaintext operands conservatively assume that the integer has
    // its maximum possible value
    sqNorm = conservativeIntNorm2Sq(iTy);
  }

  return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHE.sub_eint` operation: the squared norms of the two encrypted operands
// add up, exactly as for an addition of two encrypted values.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHE::SubEintOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().hasValue() &&
         operandMANPs[1]->getValue().getMANP().hasValue() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();

  return APIntWidthExtendUAdd(a, b);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHE.neg_eint` operation.
static llvm::APInt getSqMANP(
@@ -575,6 +626,58 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHELinalg.sub_eint_int` operation: the squared maximum element of the
// clear tensor (or a conservative bound when it is not a dense constant) is
// added to the squared norm of the encrypted operand.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::SubEintIntOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  mlir::RankedTensorType op1Ty =
      op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();

  mlir::Type iTy = op1Ty.getElementType();

  assert(iTy.isSignlessInteger() &&
         "Only subtractions with signless integers are currently allowed");

  assert(
      operandMANPs.size() == 2 &&
      operandMANPs[0]->getValue().getMANP().hasValue() &&
      "Missing squared Minimal Arithmetic Noise Padding for encrypted operand");

  llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
  llvm::APInt sqNorm;

  mlir::arith::ConstantOp cstOp =
      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
          op->getOpOperand(1).get().getDefiningOp());
  mlir::DenseIntElementsAttr denseVals =
      cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
            : nullptr;

  if (denseVals) {
    // For a dense constant operand, use the largest element's squared norm.
    sqNorm = maxIntNorm2Sq(denseVals);
  } else {
    // For dynamic plaintext operands conservatively assume that the integer has
    // its maximum possible value
    sqNorm = conservativeIntNorm2Sq(iTy);
  }

  return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of an
// `FHELinalg.sub_eint` operation: the squared norms of the two encrypted
// operands add up.
static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::SubEintOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
  assert(operandMANPs.size() == 2 &&
         operandMANPs[0]->getValue().getMANP().hasValue() &&
         operandMANPs[1]->getValue().getMANP().hasValue() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operands");

  llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
  llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();

  return APIntWidthExtendUAdd(a, b);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `FHELinalg.neg_eint` operation.
static llvm::APInt getSqMANP(
@@ -1192,6 +1295,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto subIntEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::SubIntEintOp>(op)) {
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
} else if (auto subEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHE::SubEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(subEintIntOp, operands);
} else if (auto subEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::SubEintOp>(op)) {
norm2SqEquiv = getSqMANP(subEintOp, operands);
} else if (auto negEintOp =
llvm::dyn_cast<mlir::concretelang::FHE::NegEintOp>(op)) {
norm2SqEquiv = getSqMANP(negEintOp, operands);
@@ -1219,6 +1328,14 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubIntEintOp>(
op)) {
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
} else if (auto subEintIntOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubEintIntOp>(
op)) {
norm2SqEquiv = getSqMANP(subEintIntOp, operands);
} else if (auto subEintOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubEintOp>(
op)) {
norm2SqEquiv = getSqMANP(subEintOp, operands);
} else if (auto negEintOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::NegEintOp>(
op)) {

View File

@@ -89,6 +89,35 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op,
return ::mlir::success();
}
// Verifies that the encrypted operand, the clear operand and the result of a
// `sub_eint_int` operation have consistent types (the helpers emit the
// relevant diagnostics on failure).
::mlir::LogicalResult SubEintIntOp::verify() {
  mlir::Operation &operation = *this->getOperation();
  auto encrypted = this->a().getType().cast<EncryptedIntegerType>();
  auto clear = this->b().getType().cast<IntegerType>();
  auto result = this->getResult().getType().cast<EncryptedIntegerType>();

  // Check encrypted-input/result consistency first, then the clear input
  // against the encrypted input — same order and short-circuiting as before.
  bool consistent =
      verifyEncryptedIntegerInputAndResultConsistency(operation, encrypted,
                                                      result) &&
      verifyEncryptedIntegerAndIntegerInputsConsistency(operation, encrypted,
                                                        clear);
  return consistent ? ::mlir::success() : ::mlir::failure();
}
// Verifies that both encrypted operands and the result of a `sub_eint`
// operation have consistent types (the helpers emit the relevant
// diagnostics on failure).
::mlir::LogicalResult SubEintOp::verify() {
  mlir::Operation &operation = *this->getOperation();
  auto lhs = this->a().getType().cast<EncryptedIntegerType>();
  auto rhs = this->b().getType().cast<EncryptedIntegerType>();
  auto result = this->getResult().getType().cast<EncryptedIntegerType>();

  // Check input/result consistency first, then the two inputs against each
  // other — same order and short-circuiting as before.
  bool consistent =
      verifyEncryptedIntegerInputAndResultConsistency(operation, lhs, result) &&
      verifyEncryptedIntegerInputsConsistency(operation, lhs, rhs);
  return consistent ? ::mlir::success() : ::mlir::failure();
}
::mlir::LogicalResult NegEintOp::verify() {
auto a = this->a().getType().cast<EncryptedIntegerType>();
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
@@ -147,6 +176,19 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
return nullptr;
}
// Folds `x - 0` to `x`: subtracting the clear constant zero is a no-op.
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);

  // The fold only applies when the clear operand is a known integer
  // constant equal to zero.
  if (auto clearCst = operands[1].dyn_cast_or_null<mlir::IntegerAttr>())
    if (clearCst.getInt() == 0)
      return getOperand(0);

  return nullptr;
}
// Avoid multiplication with constant 1
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);

View File

@@ -1731,6 +1731,20 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
return getOperand(0);
}
// Folds `x - 0` to `x`: subtracting a constant tensor whose elements are all
// zero is a no-op.
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 2);

  auto clearCst = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
  if (clearCst == nullptr)
    return nullptr;

  // Every element of the constant must be zero for the fold to apply.
  for (const auto &element : clearCst) {
    if (element != 0)
      return nullptr;
  }
  return getOperand(0);
}
// Avoid multiplication with constant tensor of 1s
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2);

View File

@@ -100,6 +100,60 @@ func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
// -----
// MANP of sub_eint_int with the constant 3: ceil(sqrt(1 + 3^2)) = 4.
func @single_cst_sub_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
{
  %cst = arith.constant 3 : i3
  // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  return %0 : !FHE.eint<2>
}
// -----
// Same as above with a negative constant: the squared value (-3)^2 = 9 gives
// the same MANP of 4.
func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2>
{
  %cst = arith.constant -3 : i3
  // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  %0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  return %0 : !FHE.eint<2>
}
// -----
// Dynamic clear operand: the analysis conservatively assumes its maximum
// possible value.
func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
{
  // sqrt(1 + (2^2-1)^2) = 3.16
  // CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  %0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  return %0 : !FHE.eint<2>
}
// -----
// Chained encrypted subtractions: each extra operand adds 1 to the squared
// norm, so the MANP after k subtractions is ceil(sqrt(k + 1)).
func @chain_sub_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2>, %e3: !FHE.eint<2>, %e4: !FHE.eint<2>) -> !FHE.eint<2>
{
  // CHECK: %[[V0:.*]] = "FHE.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  %0 = "FHE.sub_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%[[V0]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  %1 = "FHE.sub_eint"(%0, %e2) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint"(%[[V1]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  %2 = "FHE.sub_eint"(%1, %e3) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  // CHECK-NEXT: %[[V3:.*]] = "FHE.sub_eint"(%[[V2]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  %3 = "FHE.sub_eint"(%2, %e4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  return %3 : !FHE.eint<2>
}
// -----
func @single_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
{
// CHECK: %[[ret:.*]] = "FHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2>
@@ -147,7 +201,7 @@ func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
// -----
func @single_apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<2> {
// CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
// CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
return %1: !FHE.eint<2>
}

View File

@@ -70,6 +70,41 @@ func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> t
// -----
// MANP with a dense constant: the largest element (3) dominates, giving
// ceil(sqrt(1 + 3^2)) = 4.
func @single_cst_sub_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
  %cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
  // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
  %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
  return %0 : tensor<8x!FHE.eint<2>>
}
// -----
// Constant built via tensor.from_elements with value 1:
// ceil(sqrt(1 + 1^2)) = 2 — presumably the analysis traces through
// from_elements; the CHECK below pins that behavior.
func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
  %cst1 = arith.constant 1 : i3
  %cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
  // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
  %0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
  return %0 : tensor<8x!FHE.eint<2>>
}
// -----
// Two encrypted operands with norm 1 each: ceil(sqrt(1 + 1)) = 2.
func @single_sub_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
  // CHECK: %[[ret:.*]] = "FHELinalg.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
  %0 = "FHELinalg.sub_eint"(%e0, %e1) : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
  return %0 : tensor<8x!FHE.eint<2>>
}
// -----
func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
{
// CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>

View File

@@ -9,6 +9,15 @@ func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// Folding: subtracting the clear constant 0 is removed entirely, leaving
// only the return of the original operand.
// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
  // CHECK-NEXT: return %arg0 : !FHE.eint<2>
  %0 = arith.constant 0 : i3
  %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
  return %1: !FHE.eint<2>
}
// CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: return %arg0 : !FHE.eint<2>

View File

@@ -49,6 +49,26 @@ func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
return %1: !FHE.eint<2>
}
// Round-trip: `FHE.sub_eint_int` with a non-zero constant parses and prints
// back unchanged.
// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
  // CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
  // CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint_int"(%arg0, %[[V1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
  // CHECK-NEXT: return %[[V2]] : !FHE.eint<2>

  %0 = arith.constant 1 : i3
  %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
  return %1: !FHE.eint<2>
}
// Round-trip: `FHE.sub_eint` parses and prints back unchanged.
// CHECK-LABEL: func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
  // CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
  // CHECK-NEXT: return %[[V1]] : !FHE.eint<2>

  %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
  return %1: !FHE.eint<2>
}
// CHECK-LABEL: func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
// CHECK-NEXT: %[[V1:.*]] = "FHE.neg_eint"(%arg0) : (!FHE.eint<2>) -> !FHE.eint<2>

View File

@@ -27,6 +27,33 @@ func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FH
return %1: tensor<4x3x!FHE.eint<2>>
}
// Folding: subtracting an all-zero constant tensor folds to the encrypted
// operand.
// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
  %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}
// Folding also applies when the zero constant is broadcast from a smaller
// shape.
// CHECK: func @sub_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %a1 = arith.constant dense<[0]> : tensor<1xi3>
  %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}
// Folding with a 2D broadcast zero constant.
// CHECK: func @sub_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
  %a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
  %1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
  return %1: tensor<4x3x!FHE.eint<2>>
}
// CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }

View File

@@ -110,7 +110,7 @@ func @add_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.ein
/////////////////////////////////////////////////
// FHELinalg.sub_eint_int
// FHELinalg.sub_int_eint
/////////////////////////////////////////////////
// 1D tensor
@@ -164,6 +164,116 @@ func @sub_int_eint_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>)
}
/////////////////////////////////////////////////
// FHELinalg.sub_eint_int
/////////////////////////////////////////////////
// 1D tensor
// Round-trip parse/print of FHELinalg.sub_eint_int on 1D tensors.
// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_1D(%a0: tensor<4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}
// 2D tensor
// Round-trip parse/print of FHELinalg.sub_eint_int on 2D tensors.
// CHECK: func @sub_eint_int_2D(%[[a0:.*]]: tensor<2x4xi3>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_2D(%a0: tensor<2x4xi3>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>>
  return %1: tensor<2x4x!FHE.eint<2>>
}
// 10D tensor
// Round-trip parse/print of FHELinalg.sub_eint_int on high-rank (10D)
// tensors.
// CHECK: func @sub_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
  return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
}
// Broadcasting with tensors whose dimensions are equal to one
// Round-trip with size-1 dimensions broadcast on both operands.
// CHECK: func @sub_eint_int_broadcast_1(%[[a0:.*]]: tensor<3x4x1xi3>, %[[a1:.*]]: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_broadcast_1(%a0: tensor<3x4x1xi3>, %a1: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>>
  return %1: tensor<3x4x5x!FHE.eint<2>>
}
// Broadcasting with a tensor that has fewer dimensions than the other
// Round-trip where the encrypted operand has fewer dimensions than the
// clear one.
// CHECK: func @sub_eint_int_broadcast_2(%[[a0:.*]]: tensor<3x4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_int_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
  %1 ="FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>>
  return %1: tensor<3x4x!FHE.eint<2>>
}
/////////////////////////////////////////////////
// FHELinalg.sub_eint
/////////////////////////////////////////////////
// 1D tensor
// Round-trip parse/print of FHELinalg.sub_eint on 1D tensors.
// CHECK: func @sub_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>>
// CHECK-NEXT: }
func @sub_eint_1D(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
  return %1: tensor<4x!FHE.eint<2>>
}
// 2D tensor
// CHECK: func @sub_eint_2D(%[[a0:.*]]: tensor<2x4x!FHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>>
// CHECK-NEXT: }
// Element-wise subtraction of two encrypted 2D tensors of the same shape.
func @sub_eint_2D(%a0: tensor<2x4x!FHE.eint<2>>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>>
  return %1: tensor<2x4x!FHE.eint<2>>
}
// 10D tensor
// CHECK: func @sub_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
// CHECK-NEXT: }
// Element-wise subtraction of two encrypted rank-10 tensors of the same shape.
func @sub_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
  return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
}
// Broadcasting with tensor with dimensions equals to one
// CHECK: func @sub_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>>
// CHECK-NEXT: }
// Broadcasting: size-1 dimensions of each encrypted operand are stretched, 1x4x5 - 3x4x1 -> 3x4x5.
func @sub_eint_broadcast_1(%a0: tensor<1x4x5x!FHE.eint<2>>, %a1: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
  %1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>>
  return %1: tensor<3x4x5x!FHE.eint<2>>
}
// Broadcasting with a tensor that has fewer dimensions than the other
// CHECK: func @sub_eint_broadcast_2(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>>
// CHECK-NEXT: }
// Broadcasting: the rank-1 encrypted operand is broadcast against the 3x4 encrypted operand.
func @sub_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
  %1 ="FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>>
  return %1: tensor<3x4x!FHE.eint<2>>
}
/////////////////////////////////////////////////
// FHELinalg.neg_eint
/////////////////////////////////////////////////

View File

@@ -63,7 +63,7 @@ tests:
outputs:
- scalar: 3
---
description: sub_eint_int_cst
description: sub_int_eint_cst
program: |
func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
%0 = arith.constant 7 : i3
@@ -88,6 +88,69 @@ tests:
outputs:
- scalar: 3
---
# FHE.sub_eint_int with a clear constant: checks 10 - 7 = 3 and 7 - 7 = 0.
description: sub_eint_int_cst
program: |
  func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
    %0 = arith.constant 7 : i6
    %1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
    return %1: !FHE.eint<5>
  }
tests:
  - inputs:
      - scalar: 10
    outputs:
      - scalar: 3
  - inputs:
      - scalar: 7
    outputs:
      - scalar: 0
---
# FHE.sub_eint_int with a clear runtime argument: result is arg0 - arg1.
description: sub_eint_int_arg
program: |
  func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> {
    %1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<4>, i5) -> (!FHE.eint<4>)
    return %1: !FHE.eint<4>
  }
tests:
  - inputs:
      - scalar: 2
      - scalar: 2
    outputs:
      - scalar: 0
  - inputs:
      - scalar: 3
      - scalar: 1
    outputs:
      - scalar: 2
  - inputs:
      - scalar: 7
      - scalar: 4
    outputs:
      - scalar: 3
---
# FHE.sub_eint with two encrypted arguments: result is arg0 - arg1.
description: sub_eint
program: |
  func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
    %1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>)
    return %1: !FHE.eint<4>
  }
tests:
  - inputs:
      - scalar: 2
      - scalar: 2
    outputs:
      - scalar: 0
  - inputs:
      - scalar: 3
      - scalar: 1
    outputs:
      - scalar: 2
  - inputs:
      - scalar: 7
      - scalar: 4
    outputs:
      - scalar: 3
---
description: sub_int_eint_arg
program: |
func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {

View File

@@ -791,6 +791,447 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) {
}
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg sub_eint_int ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Element-wise subtraction of a clear 1D tensor (a0) from an encrypted 1D
// tensor (a1); each result element is checked against a1[i] - a0[i].
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
return %res : tensor<4x!FHE.eint<4>>
}
)XXX");
  // NOTE(review): a1[0] = 32 does not fit in the 4-bit encrypted plaintext
  // range; the test seems to rely on wrap-around so that 32 - 31 == 1 --
  // confirm this is intended.
  std::vector<uint8_t> a0{31, 6, 2, 3};
  std::vector<uint8_t> a1{32, 9, 12, 9};
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(a0);
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(a1);
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)4);
  // Compare the decrypted results against the clear-text differences.
  for (size_t i = 0; i < 4; i++) {
    EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]);
  }
}
// Broadcast subtraction: a clear 4x1x4 tensor is subtracted from an encrypted
// 1x4x4 tensor, both stretched along their size-1 dimension to 4x4x4.
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!FHE.eint<7>>) -> tensor<4x4x4x!FHE.eint<7>> {
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x4x!FHE.eint<7>>, tensor<4x1x4xi8>) -> tensor<4x4x4x!FHE.eint<7>>
return %res : tensor<4x4x4x!FHE.eint<7>>
}
)XXX");
  const uint8_t a0[4][1][4]{
      {{1, 2, 3, 4}},
      {{5, 6, 7, 8}},
      {{9, 10, 11, 12}},
      {{13, 14, 15, 16}},
  };
  const uint8_t a1[1][4][4]{
      {
          {1, 2, 3, 4},
          {5, 6, 7, 8},
          {9, 10, 11, 12},
          {13, 14, 15, 16},
      },
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 4 * 1 * 4), {4, 1, 4});
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 1 * 4 * 4), {1, 4, 4});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4);
  // The broadcast semantics pick a1 along its last two dims and a0 along its
  // first and last dims: result[i][j][k] == a1[0][i... -- computed below as
  // a1[i][0][k] - a0[0][j][k] to mirror the size-1 dimensions of each operand.
  for (size_t i = 0; i < 4; i++) {
    for (size_t j = 0; j < 4; j++) {
      for (size_t k = 0; k < 4; k++) {
        uint8_t expected = a1[i][0][k] - a0[0][j][k];
        EXPECT_EQ((*res)[i * 16 + j * 4 + k], (uint64_t)expected);
      }
    }
  }
}
// Subtracts a clear 3x1 column vector from each column of an encrypted 3x3
// matrix (column vector broadcast along the second dimension).
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_column) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1xi5>) ->
tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[3][1]{
      {1},
      {2},
      {3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3, 1});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  // Row i of the result equals row i of a0 minus the clear value a1[i][0].
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]);
    }
  }
}
// Subtracts a clear 1x3 row vector from each row of an encrypted 3x3 matrix
// (row vector broadcast along the first dimension).
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3xi5>) ->
tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  // 3 * 1 == 3 elements, matching the declared {1, 3} shape.
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {1, 3});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  // Column j of the result equals column j of a0 minus the clear value a1[0][j].
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
    }
  }
}
// Same as sub_eint_int_matrix_line but the clear operand is rank-1 (shape
// {3}); broadcasting prepends the missing dimension.
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line_missing_dim) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  // Declared 1x3 in C++ but passed to the lambda with the rank-1 shape {3}.
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg sub_eint ///////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
// Element-wise subtraction of two encrypted 1D tensors; each result element
// is checked against the clear-text difference a0[i] - a1[i].
TEST(End2EndJit_FHELinalg, sub_eint_term_to_term) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
return %res : tensor<4x!FHE.eint<6>>
}
)XXX");
  std::vector<uint8_t> a0{31, 6, 12, 9};
  std::vector<uint8_t> a1{4, 2, 9, 3};
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(a0);
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(a1);
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)4);
  for (size_t i = 0; i < 4; i++) {
    // Fix: the failure message previously printed a0[i] + a1[i] as the
    // expected value; the expected value of a subtraction is a0[i] - a1[i].
    EXPECT_EQ((*res)[i], (uint64_t)a0[i] - a1[i])
        << "result differ at pos " << i << ", expect " << a0[i] - a1[i]
        << " got " << (*res)[i];
  }
}
// Broadcast subtraction of two encrypted tensors: 4x1x4 minus 1x4x4, both
// stretched along their size-1 dimension to a 4x4x4 result.
TEST(End2EndJit_FHELinalg, sub_eint_term_to_term_broadcast) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<4x1x4x!FHE.eint<5>>, %a1: tensor<1x4x4x!FHE.eint<5>>) -> tensor<4x4x4x!FHE.eint<5>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<5>>, tensor<1x4x4x!FHE.eint<5>>) ->
tensor<4x4x4x!FHE.eint<5>> return %res : tensor<4x4x4x!FHE.eint<5>>
}
)XXX");
  uint8_t a0[4][1][4]{
      {{10, 20, 30, 40}},
      {{5, 6, 7, 8}},
      {{9, 10, 11, 12}},
      {{13, 14, 15, 16}},
  };
  uint8_t a1[1][4][4]{
      {
          {1, 2, 3, 4},
          {4, 3, 2, 1},
          {3, 1, 4, 2},
          {2, 4, 1, 3},
      },
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::MutableArrayRef<uint8_t>((uint8_t *)a0, 4 * 1 * 4), {4, 1, 4});
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::MutableArrayRef<uint8_t>((uint8_t *)a1, 1 * 4 * 4), {1, 4, 4});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4);
  // result[i][j][k] == a0[i][0][k] - a1[0][j][k], mirroring the size-1
  // dimension of each operand.
  for (size_t i = 0; i < 4; i++) {
    for (size_t j = 0; j < 4; j++) {
      for (size_t k = 0; k < 4; k++) {
        EXPECT_EQ((*res)[i * 16 + j * 4 + k],
                  (uint64_t)a0[i][0][k] - a1[0][j][k])
            << "result differ at pos " << i << "," << j << "," << k;
      }
    }
  }
}
// Subtracts an encrypted 3x1 column vector from each column of an encrypted
// 3x3 matrix (column vector broadcast along the second dimension).
TEST(End2EndJit_FHELinalg, sub_eint_matrix_column) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[3][1]{
      {1},
      {2},
      {3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  // Fix: arg1 was built from a0's storage (copy-paste bug). It only passed
  // because a0's first three elements {1, 2, 3} coincide with a1's contents.
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3, 1});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  // Row i of the result equals row i of a0 minus a1[i][0].
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]);
    }
  }
}
// Subtracts an encrypted 1x3 row vector from each row of an encrypted 3x3
// matrix (row vector broadcast along the first dimension).
TEST(End2EndJit_FHELinalg, sub_eint_matrix_line) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  // Fix: arg1 was built from a0's storage (copy-paste bug). It only passed
  // because a0's first three elements {1, 2, 3} coincide with a1's contents.
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {1, 3});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  // Column j of the result equals column j of a0 minus a1[0][j].
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
    }
  }
}
// Same as sub_eint_matrix_line but the subtrahend is rank-1 (shape {3});
// broadcasting prepends the missing dimension.
TEST(End2EndJit_FHELinalg, sub_eint_matrix_line_missing_dim) {
  checkedJit(lambda, R"XXX(
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
return %res : tensor<3x3x!FHE.eint<4>>
}
)XXX");
  const uint8_t a0[3][3]{
      {1, 2, 3},
      {4, 5, 6},
      {7, 8, 9},
  };
  // Declared 1x3 in C++ but passed to the lambda with the rank-1 shape {3}.
  const uint8_t a1[1][3]{
      {1, 2, 3},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
  // Fix: arg1 was built from a0's storage (copy-paste bug). It only passed
  // because a0's first three elements {1, 2, 3} coincide with a1's contents.
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 3);
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 3; j++) {
      EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
    }
  }
}
// Element-wise subtraction of two encrypted 3x1x2 tensors whose shapes match
// exactly (no broadcasting), with a size-1 middle dimension.
TEST(End2EndJit_FHELinalg, sub_eint_tensor_dim_equals_1) {
  checkedJit(lambda, R"XXX(
func @main(%arg0: tensor<3x1x2x!FHE.eint<5>>, %arg1: tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>> {
%1 = "FHELinalg.sub_eint"(%arg0, %arg1) : (tensor<3x1x2x!FHE.eint<5>>, tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>>
return %1 : tensor<3x1x2x!FHE.eint<5>>
}
)XXX");
  const uint8_t a0[3][1][2]{
      {{8, 10}},
      {{12, 14}},
      {{16, 18}},
  };
  const uint8_t a1[3][1][2]{
      {{1, 2}},
      {{4, 5}},
      {{7, 8}},
  };
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 2), {3, 1, 2});
  mlir::concretelang::TensorLambdaArgument<
      mlir::concretelang::IntLambdaArgument<uint8_t>>
      arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 2), {3, 1, 2});
  llvm::Expected<std::vector<uint64_t>> res =
      lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
  ASSERT_EXPECTED_SUCCESS(res);
  ASSERT_EQ(res->size(), (uint64_t)3 * 1 * 2);
  for (size_t i = 0; i < 3; i++) {
    for (size_t j = 0; j < 1; j++) {
      for (size_t k = 0; k < 2; k++) {
        // The flat index i * 2 + j + k is only correct because j is always 0
        // (middle dimension has size 1); the general form would be
        // i * 2 + j * 2 + k.
        EXPECT_EQ((*res)[i * 2 + j + k], (uint64_t)a0[i][j][k] - a1[i][j][k]);
      }
    }
  }
}
///////////////////////////////////////////////////////////////////////////////
// FHELinalg mul_eint_int ///////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////