feat(compiler): support HLFHE.neg_eint in MANP

This commit is contained in:
youben11
2021-11-08 16:00:31 +01:00
parent 08869bc998
commit efacd7d8a1
3 changed files with 47 additions and 0 deletions

View File

@@ -20,6 +20,7 @@ def MANP : FunctionPass<"MANP"> {
- HLFHE.add_eint_int
- HLFHE.add_eint
- HLFHE.sub_int_eint
- HLFHE.neg_eint
- HLFHE.mul_eint_int
- HLFHE.apply_lookup_table
@@ -47,6 +48,7 @@ def MANP : FunctionPass<"MANP"> {
- HLFHE.add_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e, 1], [1, c])
- HLFHE.add_eint(e0, e1) -> HLFHELinalg.dot_eint_int([e0, e1], [1, 1])
- HLFHE.sub_int_eint(c, e) -> HLFHELinalg.dot_eint_int([e, c], [1, -1])
- HLFHE.neg_eint(e) -> HLFHELinalg.dot_eint_int([e], [-1])
- HLFHE.mul_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e], [c])
Dependent dot operations, e.g.,
@@ -85,6 +87,7 @@ def MANP : FunctionPass<"MANP"> {
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
- HLFHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
The final, non-squared 2-norm of an operation is the square root of the

View File

@@ -393,6 +393,22 @@ static llvm::APInt getSqMANP(
return APIntWidthExtendUAdd(sqNorm, eNorm);
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.neg_eint` operation.
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHE::NegEintOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
assert(
operandMANPs.size() == 1 &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
return eNorm;
}
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
// that is equivalent to an `HLFHE.mul_eint_int` operation.
static llvm::APInt getSqMANP(
@@ -663,6 +679,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
} else if (auto subIntEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::SubIntEintOp>(op)) {
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
} else if (auto negEintOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::NegEintOp>(op)) {
norm2SqEquiv = getSqMANP(negEintOp, operands);
} else if (auto mulEintIntOp =
llvm::dyn_cast<mlir::zamalang::HLFHE::MulEintIntOp>(op)) {
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);

View File

@@ -86,6 +86,16 @@ func @single_dyn_sub_int_eint(%e: !HLFHE.eint<2>, %i: i3) -> !HLFHE.eint<2>
// -----
func @single_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
{
// CHECK: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
%0 = "HLFHE.neg_eint"(%e) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
return %0 : !HLFHE.eint<2>
}
// -----
func @single_cst_mul_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
{
%cst = arith.constant 3 : i3
@@ -197,3 +207,18 @@ func @chain_add_eint(%e0: !HLFHE.eint<2>, %e1: !HLFHE.eint<2>, %e2: !HLFHE.eint<
return %3 : !HLFHE.eint<2>
}
// -----
func @chain_add_eint_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
{
%cst0 = arith.constant 3 : i3
// CHECK: %[[ret:.*]] = "HLFHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
%0 = "HLFHE.add_eint_int"(%e, %cst0) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
// CHECK-NEXT: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
%1 = "HLFHE.neg_eint"(%0) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
return %1 : !HLFHE.eint<2>
}