mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: implement all kinds of subtractions
This commit is contained in:
@@ -148,6 +148,71 @@ def SubIntEintOp : FHE_Op<"sub_int_eint"> {
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SubEintIntOp : FHE_Op<"sub_eint_int"> {
|
||||
|
||||
let summary = "Substract a clear integer from an encrypted integer";
|
||||
|
||||
let description = [{
|
||||
Substract a clear integer from an encrypted integer.
|
||||
The clear integer must have at most one more bit than the encrypted integer
|
||||
and the result must have the same width than the encrypted integer.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
// error
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i4) -> !FHE.eint<2>
|
||||
"FHE.sub_eint_int"(%i, %a) : (!FHE.eint<2>, i3) -> !FHE.eint<3>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins EncryptedIntegerType:$a, AnyInteger:$b);
|
||||
let results = (outs EncryptedIntegerType);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def SubEintOp : FHE_Op<"sub_eint"> {
|
||||
|
||||
let summary = "Subtracts two encrypted integers";
|
||||
|
||||
let description = [{
|
||||
Subtracts two encrypted integers
|
||||
The encrypted integers and the result must have the same width.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
// ok
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
|
||||
// error
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<3>) -> (!FHE.eint<2>)
|
||||
"FHE.sub_eint"(%a, %b): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<3>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins EncryptedIntegerType:$a, EncryptedIntegerType:$b);
|
||||
let results = (outs EncryptedIntegerType);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$a, "Value":$b), [{
|
||||
build($_builder, $_state, a.getType(), a, b);
|
||||
}]>
|
||||
];
|
||||
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NegEintOp : FHE_Op<"neg_eint"> {
|
||||
|
||||
let summary = "Negates an encrypted integer";
|
||||
@@ -224,7 +289,7 @@ def ApplyLookupTableEintOp : FHE_Op<"apply_lookup_table"> {
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<2>)
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<3>, tensor<4xi64>) -> (!FHE.eint<2>)
|
||||
|
||||
|
||||
// error
|
||||
"FHE.apply_lookup_table"(%a, %lut): (!FHE.eint<2>, tensor<8xi64>) -> (!FHE.eint<2>)
|
||||
```
|
||||
|
||||
@@ -181,6 +181,114 @@ def SubIntEintOp : FHELinalg_Op<"sub_int_eint", [TensorBroadcastingRules, Tensor
|
||||
];
|
||||
}
|
||||
|
||||
def SubEintIntOp : FHELinalg_Op<"sub_eint_int", [TensorBroadcastingRules, TensorBinaryEintInt]> {
|
||||
let summary = "Returns a tensor that contains the substraction of a tensor of clear integers from a tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Performs a substraction following the broadcasting rules between a tensor of clear integers from a tensor of encrypted integers.
|
||||
The width of the clear integers must be less than or equals to the witdh of encrypted integers.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term substraction of `%a0` with `%a1`
|
||||
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term substraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x4x4x!FHE.eint<4>>, tensor<4x1x4xi5>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
|
||||
//
|
||||
// [1,2,3] [1] [0,2,3]
|
||||
// [4,5,6] - [2] = [2,3,4]
|
||||
// [7,8,9] [3] [4,5,6]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x1x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
|
||||
//
|
||||
// [1,2,3] [0,0,0]
|
||||
// [4,5,6] - [1,2,3] = [3,3,3]
|
||||
// [7,8,9] [6,6,6]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<1x3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 is missing of operand #2.
|
||||
"FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x!FHE.eint<4>>, tensor<3x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$rhs
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
|
||||
build($_builder, $_state, lhs.getType(), lhs, rhs);
|
||||
}]>
|
||||
|
||||
];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def SubEintOp : FHELinalg_Op<"sub_eint", [TensorBroadcastingRules, TensorBinaryEint]> {
|
||||
let summary = "Returns a tensor that contains the subtraction of two tensor of encrypted integers.";
|
||||
|
||||
let description = [{
|
||||
Performs an subtraction follwing the broadcasting rules between two tensors of encrypted integers.
|
||||
The width of the encrypted integers must be equal.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
// Returns the term to term subtraction of `%a0` with `%a1`
|
||||
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<4>>, tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>>
|
||||
|
||||
// Returns the term to term subtraction of `%a0` with `%a1`, where dimensions equal to one are stretched.
|
||||
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<4>>, tensor<1x4x4x!FHE.eint<4>>) -> tensor<4x4x4x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 3x1 matrix (a column) of encrypted integers.
|
||||
//
|
||||
// [1,2,3] [1] [0,2,3]
|
||||
// [4,5,6] - [2] = [2,3,4]
|
||||
// [7,8,9] [3] [4,5,6]
|
||||
//
|
||||
// The dimension #1 of operand #2 is stretched as it is equals to 1.
|
||||
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Returns the substraction of a 3x3 matrix of integers and a 1x3 matrix (a line) of encrypted integers.
|
||||
//
|
||||
// [1,2,3] [0,0,0]
|
||||
// [4,5,6] - [1,2,3] = [3,3,3]
|
||||
// [7,8,9] [6,6,6]
|
||||
//
|
||||
// The dimension #2 of operand #2 is stretched as it is equals to 1.
|
||||
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
|
||||
// Same behavior than the previous one, but as the dimension #2 of operand #2 is missing.
|
||||
"FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$lhs,
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$rhs
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$lhs, "Value":$rhs), [{
|
||||
build($_builder, $_state, lhs.getType(), lhs, rhs);
|
||||
}]>
|
||||
];
|
||||
}
|
||||
|
||||
def NegEintOp : FHELinalg_Op<"neg_eint", [TensorUnaryEint]> {
|
||||
let summary = "Returns a tensor that contains the negation of a tensor of encrypted integers.";
|
||||
|
||||
|
||||
@@ -1613,6 +1613,14 @@ void FHETensorOpsToLinalg::runOnOperation() {
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubIntEintOp,
|
||||
mlir::concretelang::FHE::SubIntEintOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubEintIntOp,
|
||||
mlir::concretelang::FHE::SubEintIntOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::SubEintOp,
|
||||
mlir::concretelang::FHE::SubEintOp>>(
|
||||
&getContext());
|
||||
patterns.insert<
|
||||
FHELinalgOpToLinalgGeneric<mlir::concretelang::FHELinalg::MulEintIntOp,
|
||||
mlir::concretelang::FHE::MulEintIntOp>>(
|
||||
|
||||
@@ -115,6 +115,85 @@ struct ApplyLookupTableEintOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of `FHE.sub_eint_int`
|
||||
// operators to a negation and an addition.
|
||||
struct SubEintIntOpPattern : public mlir::OpRewritePattern<FHE::SubEintIntOp> {
|
||||
SubEintIntOpPattern(mlir::MLIRContext *context,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::SubEintIntOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintIntOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
|
||||
mlir::Value lhs = op.getOperand(0);
|
||||
mlir::Value rhs = op.getOperand(1);
|
||||
|
||||
mlir::Type rhsType = rhs.getType();
|
||||
mlir::Attribute minusOneAttr = mlir::IntegerAttr::get(rhsType, -1);
|
||||
mlir::Value minusOne =
|
||||
rewriter.create<mlir::arith::ConstantOp>(location, minusOneAttr)
|
||||
.getResult();
|
||||
|
||||
mlir::Value negative =
|
||||
rewriter.create<mlir::arith::MulIOp>(location, rhs, minusOne)
|
||||
.getResult();
|
||||
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto resultTy = converter.convertType(op.getType());
|
||||
|
||||
auto addition =
|
||||
rewriter.create<TFHE::AddGLWEIntOp>(location, resultTy, lhs, negative);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, {addition.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
// This rewrite pattern transforms any instance of `FHE.sub_eint`
|
||||
// operators to a negation and an addition.
|
||||
struct SubEintOpPattern : public mlir::OpRewritePattern<FHE::SubEintOp> {
|
||||
SubEintOpPattern(mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<FHE::SubEintOp>(context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(FHE::SubEintOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Location location = op.getLoc();
|
||||
|
||||
mlir::Value lhs = op.getOperand(0);
|
||||
mlir::Value rhs = op.getOperand(1);
|
||||
|
||||
FHEToTFHETypeConverter converter;
|
||||
|
||||
auto rhsTy = converter.convertType(rhs.getType());
|
||||
auto negative = rewriter.create<TFHE::NegGLWEOp>(location, rhsTy, rhs);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, negative, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
auto resultTy = converter.convertType(op.getType());
|
||||
auto addition = rewriter.create<TFHE::AddGLWEOp>(location, resultTy, lhs,
|
||||
negative.getResult());
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, addition, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
rewriter.replaceOp(op, {addition.getResult()});
|
||||
return mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
void FHEToTFHEPass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
@@ -123,6 +202,7 @@ void FHEToTFHEPass::runOnOperation() {
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<mlir::concretelang::TFHE::TFHEDialect>();
|
||||
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
|
||||
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<mlir::concretelang::FHE::FHEDialect>();
|
||||
@@ -155,6 +235,9 @@ void FHEToTFHEPass::runOnOperation() {
|
||||
patterns.getContext(), converter);
|
||||
|
||||
patterns.add<ApplyLookupTableEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintOpPattern>(&getContext());
|
||||
patterns.add<SubEintIntOpPattern>(&getContext());
|
||||
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
FHEToTFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
@@ -429,6 +429,57 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `FHE.sub_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::SubEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
mlir::Type iTy = op->getOpOperand(1).get().getType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only subtractions with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
|
||||
if (cstOp) {
|
||||
// For constant plaintext operands simply use the constant value
|
||||
mlir::IntegerAttr attr = cstOp->getAttrOfType<mlir::IntegerAttr>("value");
|
||||
sqNorm = APIntWidthExtendSqForConstant(attr.getValue());
|
||||
} else {
|
||||
// For dynamic plaintext operands conservatively assume that the integer has
|
||||
// its maximum possible value
|
||||
sqNorm = conservativeIntNorm2Sq(iTy);
|
||||
}
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `FHE.sub_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHE::SubEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `FHE.neg_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -575,6 +626,58 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SubEintIntOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
mlir::RankedTensorType op1Ty =
|
||||
op->getOpOperand(1).get().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
mlir::Type iTy = op1Ty.getElementType();
|
||||
|
||||
assert(iTy.isSignlessInteger() &&
|
||||
"Only subtractions with signless integers are currently allowed");
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt sqNorm;
|
||||
|
||||
mlir::arith::ConstantOp cstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
op->getOpOperand(1).get().getDefiningOp());
|
||||
mlir::DenseIntElementsAttr denseVals =
|
||||
cstOp ? cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
|
||||
: nullptr;
|
||||
|
||||
if (denseVals) {
|
||||
sqNorm = maxIntNorm2Sq(denseVals);
|
||||
} else {
|
||||
// For dynamic plaintext operands conservatively assume that the integer has
|
||||
// its maximum possible value
|
||||
sqNorm = conservativeIntNorm2Sq(iTy);
|
||||
}
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::concretelang::FHELinalg::SubEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
assert(operandMANPs.size() == 2 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
operandMANPs[1]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
|
||||
"operands");
|
||||
|
||||
llvm::APInt a = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
llvm::APInt b = operandMANPs[1]->getValue().getMANP().getValue();
|
||||
|
||||
return APIntWidthExtendUAdd(a, b);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `FHELinalg.neg_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -1192,6 +1295,12 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto subIntEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::SubIntEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
|
||||
} else if (auto subEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::SubEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(subEintIntOp, operands);
|
||||
} else if (auto subEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::SubEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(subEintOp, operands);
|
||||
} else if (auto negEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHE::NegEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(negEintOp, operands);
|
||||
@@ -1219,6 +1328,14 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubIntEintOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
|
||||
} else if (auto subEintIntOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubEintIntOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(subEintIntOp, operands);
|
||||
} else if (auto subEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::SubEintOp>(
|
||||
op)) {
|
||||
norm2SqEquiv = getSqMANP(subEintOp, operands);
|
||||
} else if (auto negEintOp =
|
||||
llvm::dyn_cast<mlir::concretelang::FHELinalg::NegEintOp>(
|
||||
op)) {
|
||||
|
||||
@@ -89,6 +89,35 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::Operation &op,
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
::mlir::LogicalResult SubEintIntOp::verify() {
|
||||
auto a = this->a().getType().cast<EncryptedIntegerType>();
|
||||
auto b = this->b().getType().cast<IntegerType>();
|
||||
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyEncryptedIntegerAndIntegerInputsConsistency(*this->getOperation(),
|
||||
a, b)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
::mlir::LogicalResult SubEintOp::verify() {
|
||||
auto a = this->a().getType().cast<EncryptedIntegerType>();
|
||||
auto b = this->b().getType().cast<EncryptedIntegerType>();
|
||||
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
|
||||
if (!verifyEncryptedIntegerInputAndResultConsistency(*this->getOperation(), a,
|
||||
out)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
if (!verifyEncryptedIntegerInputsConsistency(*this->getOperation(), a, b)) {
|
||||
return ::mlir::failure();
|
||||
}
|
||||
return ::mlir::success();
|
||||
}
|
||||
|
||||
::mlir::LogicalResult NegEintOp::verify() {
|
||||
auto a = this->a().getType().cast<EncryptedIntegerType>();
|
||||
auto out = this->getResult().getType().cast<EncryptedIntegerType>();
|
||||
@@ -147,6 +176,19 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Avoid subtraction with constant 0
|
||||
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toSub = operands[1].dyn_cast_or_null<mlir::IntegerAttr>();
|
||||
if (toSub != nullptr) {
|
||||
auto intToSub = toSub.getInt();
|
||||
if (intToSub == 0) {
|
||||
return getOperand(0);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Avoid multiplication with constant 1
|
||||
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
|
||||
@@ -1731,6 +1731,20 @@ OpFoldResult AddEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
// Avoid subtraction with constant tensor of 0s
|
||||
OpFoldResult SubEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
auto toSub = operands[1].dyn_cast_or_null<mlir::DenseIntElementsAttr>();
|
||||
if (toSub == nullptr)
|
||||
return nullptr;
|
||||
for (auto it = toSub.begin(); it != toSub.end(); it++) {
|
||||
if (*it != 0) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return getOperand(0);
|
||||
}
|
||||
|
||||
// Avoid multiplication with constant tensor of 1s
|
||||
OpFoldResult MulEintIntOp::fold(ArrayRef<Attribute> operands) {
|
||||
assert(operands.size() == 2);
|
||||
|
||||
@@ -100,6 +100,60 @@ func @single_dyn_sub_int_eint(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_sub_eint_int(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant 3 : i3
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_sub_eint_int_neg(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant -3 : i3
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.sub_eint_int"(%e, %cst) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_dyn_sub_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
{
|
||||
// sqrt(1 + (2^2-1)^2) = 3.16
|
||||
// CHECK: %[[ret:.*]] = "FHE.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
%0 = "FHE.sub_eint_int"(%e, %i) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
|
||||
return %0 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @chain_sub_eint(%e0: !FHE.eint<2>, %e1: !FHE.eint<2>, %e2: !FHE.eint<2>, %e3: !FHE.eint<2>, %e4: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[V0:.*]] = "FHE.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%0 = "FHE.sub_eint"(%e0, %e1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%[[V0]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%1 = "FHE.sub_eint"(%0, %e2) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint"(%[[V1]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%2 = "FHE.sub_eint"(%1, %e3) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
// CHECK-NEXT: %[[V3:.*]] = "FHE.sub_eint"(%[[V2]], %[[op1:.*]]) {MANP = 3 : ui{{[0-9]+}}} : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
%3 = "FHE.sub_eint"(%2, %e4) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
return %3 : !FHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_neg_eint(%e: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>) -> !FHE.eint<2>
|
||||
@@ -147,7 +201,7 @@ func @single_dyn_mul_eint_int(%e: !FHE.eint<2>, %i: i3) -> !FHE.eint<2>
|
||||
// -----
|
||||
|
||||
func @single_apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<2> {
|
||||
// CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
// CHECK: %[[ret:.*]] = "FHE.apply_lookup_table"(%[[op0:.*]], %[[op1:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
@@ -70,6 +70,41 @@ func @single_cst_sub_int_eint_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> t
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_sub_eint_int(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
%cst = arith.constant dense<[0, 1, 2, 3, 3, 2, 1, 0]> : tensor<8xi3>
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_sub_eint_int_from_cst_elements(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
%cst1 = arith.constant 1 : i3
|
||||
%cst = tensor.from_elements %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1, %cst1: tensor<8xi3>
|
||||
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.sub_eint_int"(%e, %cst) : (tensor<8x!FHE.eint<2>>, tensor<8xi3>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_sub_eint(%e0: tensor<8x!FHE.eint<2>>, %e1: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.sub_eint"(%[[op0:.*]], %[[op1:.*]]) {MANP = 2 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
%0 = "FHELinalg.sub_eint"(%e0, %e1) : (tensor<8x!FHE.eint<2>>, tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
return %0 : tensor<8x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_neg_eint(%e: tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "FHELinalg.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (tensor<8x!FHE.eint<2>>) -> tensor<8x!FHE.eint<2>>
|
||||
|
||||
@@ -9,6 +9,15 @@ func @add_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 0 : i3
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @mul_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: return %arg0 : !FHE.eint<2>
|
||||
|
||||
@@ -49,6 +49,26 @@ func @sub_int_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @sub_eint_int(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i3
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.sub_eint_int"(%arg0, %[[V1]]) : (!FHE.eint<2>, i3) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V2]] : !FHE.eint<2>
|
||||
|
||||
%0 = arith.constant 1 : i3
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<2>, i3) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @sub_eint(%arg0: !FHE.eint<2>, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.sub_eint"(%arg0, %arg1) : (!FHE.eint<2>, !FHE.eint<2>) -> !FHE.eint<2>
|
||||
// CHECK-NEXT: return %[[V1]] : !FHE.eint<2>
|
||||
|
||||
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<2>, !FHE.eint<2>) -> (!FHE.eint<2>)
|
||||
return %1: !FHE.eint<2>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2>
|
||||
func @neg_eint(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.neg_eint"(%arg0) : (!FHE.eint<2>) -> !FHE.eint<2>
|
||||
|
||||
@@ -27,6 +27,33 @@ func @add_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FH
|
||||
return %1: tensor<4x3x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_1D(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
%a1 = arith.constant dense<[0, 0, 0, 0]> : tensor<4xi3>
|
||||
%1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
return %1: tensor<4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func @sub_eint_int_1D_broadcast(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_1D_broadcast(%a0: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
%a1 = arith.constant dense<[0]> : tensor<1xi3>
|
||||
%1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<1xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
return %1: tensor<4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func @sub_eint_int_2D_broadcast(%[[a0:.*]]: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: return %[[a0]] : tensor<4x3x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_2D_broadcast(%a0: tensor<4x3x!FHE.eint<2>>) -> tensor<4x3x!FHE.eint<2>> {
|
||||
%a1 = arith.constant dense<[[0]]> : tensor<1x1xi3>
|
||||
%1 = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<4x3x!FHE.eint<2>>, tensor<1x1xi3>) -> tensor<4x3x!FHE.eint<2>>
|
||||
return %1: tensor<4x3x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// CHECK: func @mul_eint_int_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: return %[[a0]] : tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
|
||||
@@ -110,7 +110,7 @@ func @add_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.ein
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.sub_eint_int
|
||||
// FHELinalg.sub_int_eint
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// 1D tensor
|
||||
@@ -164,6 +164,116 @@ func @sub_int_eint_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>)
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.sub_eint_int
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// 1D tensor
|
||||
// CHECK: func @sub_eint_int_1D(%[[a0:.*]]: tensor<4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_1D(%a0: tensor<4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<4xi3>) -> tensor<4x!FHE.eint<2>>
|
||||
return %1: tensor<4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// 2D tensor
|
||||
// CHECK: func @sub_eint_int_2D(%[[a0:.*]]: tensor<2x4xi3>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_2D(%a0: tensor<2x4xi3>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4xi3>) -> tensor<2x4x!FHE.eint<2>>
|
||||
return %1: tensor<2x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// 10D tensor
|
||||
// CHECK: func @sub_eint_int_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10xi3>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10xi3>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with tensor with dimensions equals to one
|
||||
// CHECK: func @sub_eint_int_broadcast_1(%[[a0:.*]]: tensor<3x4x1xi3>, %[[a1:.*]]: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_broadcast_1(%a0: tensor<3x4x1xi3>, %a1: tensor<1x4x5x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1xi3>) -> tensor<3x4x5x!FHE.eint<2>>
|
||||
return %1: tensor<3x4x5x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with a tensor less dimensions of another
|
||||
// CHECK: func @sub_eint_int_broadcast_2(%[[a0:.*]]: tensor<3x4xi3>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint_int"(%[[a1]], %[[a0]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_int_broadcast_2(%a0: tensor<3x4xi3>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
|
||||
%1 ="FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!FHE.eint<2>>
|
||||
return %1: tensor<3x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.sub_eint
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// 1D tensor
|
||||
// CHECK: func @sub_eint_1D(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_1D(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<4x!FHE.eint<2>>) -> tensor<4x!FHE.eint<2>>
|
||||
return %1: tensor<4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// 2D tensor
|
||||
// CHECK: func @sub_eint_2D(%[[a0:.*]]: tensor<2x4x!FHE.eint<2>>, %[[a1:.*]]: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<2x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_2D(%a0: tensor<2x4x!FHE.eint<2>>, %a1: tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<2x4x!FHE.eint<2>>, tensor<2x4x!FHE.eint<2>>) -> tensor<2x4x!FHE.eint<2>>
|
||||
return %1: tensor<2x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// 10D tensor
|
||||
// CHECK: func @sub_eint_10D(%[[a0:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %[[a1:.*]]: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_10D(%a0: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, %a1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>, tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>) -> tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
return %1: tensor<1x2x3x4x5x6x7x8x9x10x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with tensor with dimensions equals to one
|
||||
// CHECK: func @sub_eint_broadcast_1(%[[a0:.*]]: tensor<1x4x5x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x5x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_broadcast_1(%a0: tensor<1x4x5x!FHE.eint<2>>, %a1: tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>> {
|
||||
%1 = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<1x4x5x!FHE.eint<2>>, tensor<3x4x1x!FHE.eint<2>>) -> tensor<3x4x5x!FHE.eint<2>>
|
||||
return %1: tensor<3x4x5x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// Broadcasting with a tensor less dimensions of another
|
||||
// CHECK: func @sub_eint_broadcast_2(%[[a0:.*]]: tensor<4x!FHE.eint<2>>, %[[a1:.*]]: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.sub_eint"(%[[a0]], %[[a1]]) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V0]] : tensor<3x4x!FHE.eint<2>>
|
||||
// CHECK-NEXT: }
|
||||
func @sub_eint_broadcast_2(%a0: tensor<4x!FHE.eint<2>>, %a1: tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>> {
|
||||
%1 ="FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<2>>, tensor<3x4x!FHE.eint<2>>) -> tensor<3x4x!FHE.eint<2>>
|
||||
return %1: tensor<3x4x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.neg_eint
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
@@ -63,7 +63,7 @@ tests:
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_eint_int_cst
|
||||
description: sub_int_eint_cst
|
||||
program: |
|
||||
func @main(%arg0: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
%0 = arith.constant 7 : i3
|
||||
@@ -88,6 +88,69 @@ tests:
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_eint_int_cst
|
||||
program: |
|
||||
func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = arith.constant 7 : i6
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %0): (!FHE.eint<5>, i6) -> (!FHE.eint<5>)
|
||||
return %1: !FHE.eint<5>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 10
|
||||
outputs:
|
||||
- scalar: 3
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
outputs:
|
||||
- scalar: 0
|
||||
---
|
||||
description: sub_eint_int_arg
|
||||
program: |
|
||||
func @main(%arg0: !FHE.eint<4>, %arg1: i5) -> !FHE.eint<4> {
|
||||
%1 = "FHE.sub_eint_int"(%arg0, %arg1): (!FHE.eint<4>, i5) -> (!FHE.eint<4>)
|
||||
return %1: !FHE.eint<4>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 2
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_eint
|
||||
program: |
|
||||
func @main(%arg0: !FHE.eint<4>, %arg1: !FHE.eint<4>) -> !FHE.eint<4> {
|
||||
%1 = "FHE.sub_eint"(%arg0, %arg1): (!FHE.eint<4>, !FHE.eint<4>) -> (!FHE.eint<4>)
|
||||
return %1: !FHE.eint<4>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 2
|
||||
- scalar: 2
|
||||
outputs:
|
||||
- scalar: 0
|
||||
- inputs:
|
||||
- scalar: 3
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 2
|
||||
- inputs:
|
||||
- scalar: 7
|
||||
- scalar: 4
|
||||
outputs:
|
||||
- scalar: 3
|
||||
---
|
||||
description: sub_int_eint_arg
|
||||
program: |
|
||||
func @main(%arg0: i3, %arg1: !FHE.eint<2>) -> !FHE.eint<2> {
|
||||
|
||||
@@ -791,6 +791,447 @@ TEST(End2EndJit_FHELinalg, sub_int_eint_matrix_line_missing_dim) {
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg sub_eint_int ///////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<4xi5>, %a1: tensor<4x!FHE.eint<4>>) -> tensor<4x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<4x!FHE.eint<4>>, tensor<4xi5>) -> tensor<4x!FHE.eint<4>>
|
||||
return %res : tensor<4x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
std::vector<uint8_t> a0{31, 6, 2, 3};
|
||||
std::vector<uint8_t> a1{32, 9, 12, 9};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
EXPECT_EQ((*res)[i], (uint64_t)a1[i] - a0[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_term_to_term_broadcast) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<4x1x4xi8>, %a1: tensor<1x4x4x!FHE.eint<7>>) -> tensor<4x4x4x!FHE.eint<7>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a1, %a0) : (tensor<1x4x4x!FHE.eint<7>>, tensor<4x1x4xi8>) -> tensor<4x4x4x!FHE.eint<7>>
|
||||
return %res : tensor<4x4x4x!FHE.eint<7>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[4][1][4]{
|
||||
{{1, 2, 3, 4}},
|
||||
{{5, 6, 7, 8}},
|
||||
{{9, 10, 11, 12}},
|
||||
{{13, 14, 15, 16}},
|
||||
};
|
||||
const uint8_t a1[1][4][4]{
|
||||
{
|
||||
{1, 2, 3, 4},
|
||||
{5, 6, 7, 8},
|
||||
{9, 10, 11, 12},
|
||||
{13, 14, 15, 16},
|
||||
},
|
||||
};
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 4 * 1 * 4), {4, 1, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 1 * 4 * 4), {1, 4, 4});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
for (size_t j = 0; j < 4; j++) {
|
||||
for (size_t k = 0; k < 4; k++) {
|
||||
uint8_t expected = a1[i][0][k] - a0[0][j][k];
|
||||
EXPECT_EQ((*res)[i * 16 + j * 4 + k], (uint64_t)expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_column) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1xi5>) ->
|
||||
tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[3][1]{
|
||||
{1},
|
||||
{2},
|
||||
{3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3, 1});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3xi5>) ->
|
||||
tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[1][3]{
|
||||
{1, 2, 3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {1, 3});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_int_matrix_line_missing_dim) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint_int"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3xi5>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[1][3]{
|
||||
{1, 2, 3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 1), {3});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg add_eint ///////////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_term_to_term) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<4x!FHE.eint<6>>, %a1: tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x!FHE.eint<6>>, tensor<4x!FHE.eint<6>>) -> tensor<4x!FHE.eint<6>>
|
||||
return %res : tensor<4x!FHE.eint<6>>
|
||||
}
|
||||
)XXX");
|
||||
|
||||
std::vector<uint8_t> a0{31, 6, 12, 9};
|
||||
std::vector<uint8_t> a1{4, 2, 9, 3};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(a0);
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(a1);
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
EXPECT_EQ((*res)[i], (uint64_t)a0[i] - a1[i])
|
||||
<< "result differ at pos " << i << ", expect " << a0[i] + a1[i]
|
||||
<< " got " << (*res)[i];
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_term_to_term_broadcast) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<4x1x4x!FHE.eint<5>>, %a1: tensor<1x4x4x!FHE.eint<5>>) -> tensor<4x4x4x!FHE.eint<5>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<4x1x4x!FHE.eint<5>>, tensor<1x4x4x!FHE.eint<5>>) ->
|
||||
tensor<4x4x4x!FHE.eint<5>> return %res : tensor<4x4x4x!FHE.eint<5>>
|
||||
}
|
||||
)XXX");
|
||||
uint8_t a0[4][1][4]{
|
||||
{{10, 20, 30, 40}},
|
||||
{{5, 6, 7, 8}},
|
||||
{{9, 10, 11, 12}},
|
||||
{{13, 14, 15, 16}},
|
||||
};
|
||||
uint8_t a1[1][4][4]{
|
||||
{
|
||||
{1, 2, 3, 4},
|
||||
{4, 3, 2, 1},
|
||||
{3, 1, 4, 2},
|
||||
{2, 4, 1, 3},
|
||||
},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::MutableArrayRef<uint8_t>((uint8_t *)a0, 4 * 1 * 4), {4, 1, 4});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::MutableArrayRef<uint8_t>((uint8_t *)a1, 1 * 4 * 4), {1, 4, 4});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)4 * 4 * 4);
|
||||
|
||||
for (size_t i = 0; i < 4; i++) {
|
||||
for (size_t j = 0; j < 4; j++) {
|
||||
for (size_t k = 0; k < 4; k++) {
|
||||
EXPECT_EQ((*res)[i * 16 + j * 4 + k],
|
||||
(uint64_t)a0[i][0][k] - a1[0][j][k])
|
||||
<< "result differ at pos " << i << "," << j << "," << k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_matrix_column) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x1x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[3][1]{
|
||||
{1},
|
||||
{2},
|
||||
{3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 1), {3, 1});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[i][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_matrix_line) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<1x3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[1][3]{
|
||||
{1, 2, 3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 1), {1, 3});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_matrix_line_missing_dim) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%a0: tensor<3x3x!FHE.eint<4>>, %a1: tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>> {
|
||||
%res = "FHELinalg.sub_eint"(%a0, %a1) : (tensor<3x3x!FHE.eint<4>>, tensor<3x!FHE.eint<4>>) -> tensor<3x3x!FHE.eint<4>>
|
||||
return %res : tensor<3x3x!FHE.eint<4>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][3]{
|
||||
{1, 2, 3},
|
||||
{4, 5, 6},
|
||||
{7, 8, 9},
|
||||
};
|
||||
const uint8_t a1[1][3]{
|
||||
{1, 2, 3},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 3), {3, 3});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 1), {3});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 3);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 3; j++) {
|
||||
EXPECT_EQ((*res)[i * 3 + j], (uint64_t)a0[i][j] - a1[0][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(End2EndJit_FHELinalg, sub_eint_tensor_dim_equals_1) {
|
||||
|
||||
checkedJit(lambda, R"XXX(
|
||||
func @main(%arg0: tensor<3x1x2x!FHE.eint<5>>, %arg1: tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>> {
|
||||
%1 = "FHELinalg.sub_eint"(%arg0, %arg1) : (tensor<3x1x2x!FHE.eint<5>>, tensor<3x1x2x!FHE.eint<5>>) -> tensor<3x1x2x!FHE.eint<5>>
|
||||
return %1 : tensor<3x1x2x!FHE.eint<5>>
|
||||
}
|
||||
)XXX");
|
||||
const uint8_t a0[3][1][2]{
|
||||
{{8, 10}},
|
||||
{{12, 14}},
|
||||
{{16, 18}},
|
||||
};
|
||||
const uint8_t a1[3][1][2]{
|
||||
{{1, 2}},
|
||||
{{4, 5}},
|
||||
{{7, 8}},
|
||||
};
|
||||
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg0(llvm::ArrayRef<uint8_t>((const uint8_t *)a0, 3 * 2), {3, 1, 2});
|
||||
mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>
|
||||
arg1(llvm::ArrayRef<uint8_t>((const uint8_t *)a1, 3 * 2), {3, 1, 2});
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> res =
|
||||
lambda.operator()<std::vector<uint64_t>>({&arg0, &arg1});
|
||||
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
|
||||
ASSERT_EQ(res->size(), (uint64_t)3 * 1 * 2);
|
||||
|
||||
for (size_t i = 0; i < 3; i++) {
|
||||
for (size_t j = 0; j < 1; j++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
EXPECT_EQ((*res)[i * 2 + j + k], (uint64_t)a0[i][j][k] - a1[i][j][k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// FHELinalg mul_eint_int ///////////////////////////////////////////////////
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Reference in New Issue
Block a user