From efacd7d8a1d3ac2c7f09554209d84145fcbfb85d Mon Sep 17 00:00:00 2001 From: youben11 Date: Mon, 8 Nov 2021 16:00:31 +0100 Subject: [PATCH] feat(compiler): support HLFHE.neg_eint in MANP --- .../zamalang/Dialect/HLFHE/Analysis/MANP.td | 3 +++ compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 19 ++++++++++++++ .../tests/Dialect/HLFHE/Analysis/MANP.mlir | 25 +++++++++++++++++++ 3 files changed, 47 insertions(+) diff --git a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td index ee6753db9..80e1cbf21 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td +++ b/compiler/include/zamalang/Dialect/HLFHE/Analysis/MANP.td @@ -20,6 +20,7 @@ def MANP : FunctionPass<"MANP"> { - HLFHE.add_eint_int - HLFHE.add_eint - HLFHE.sub_int_eint + - HLFHE.neg_eint - HLFHE.mul_eint_int - HLFHE.apply_lookup_table @@ -47,6 +48,7 @@ def MANP : FunctionPass<"MANP"> { - HLFHE.add_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e, 1], [1, c]) - HLFHE.add_eint(e0, e1) -> HLFHELinalg.dot_eint_int([e0, e1], [1, 1]) - HLFHE.sub_int_eint(c, e) -> HLFHELinalg.dot_eint_int([e, c], [1, -1]) + - HLFHE.neg_eint(e) -> HLFHELinalg.dot_eint_int([e], [-1]) - HLFHE.mul_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e], [c]) Dependent dot operations, e.g., @@ -85,6 +87,7 @@ def MANP : FunctionPass<"MANP"> { - HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c - HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2) - HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c + - HLFHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e) - HLFHE.mul_eint_int(e, c) -> c*c*sqN(e) The final, non-squared 2-norm of an operation is the square root of the diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index 199309cf7..0a4d1f097 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -393,6 +393,22 @@ static llvm::APInt getSqMANP( return APIntWidthExtendUAdd(sqNorm, eNorm); } +// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation +// that is equivalent to an `HLFHE.neg_eint` operation. +static llvm::APInt getSqMANP( + mlir::zamalang::HLFHE::NegEintOp op, + llvm::ArrayRef *> operandMANPs) { + + assert( + operandMANPs.size() == 1 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted operand"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); + + return eNorm; +} + // Calculates the squared Minimal Arithmetic Noise Padding of a dot operation // that is equivalent to an `HLFHE.mul_eint_int` operation. static llvm::APInt getSqMANP( @@ -663,6 +679,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { } else if (auto subIntEintOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(subIntEintOp, operands); + } else if (auto negEintOp = + llvm::dyn_cast(op)) { + norm2SqEquiv = getSqMANP(negEintOp, operands); } else if (auto mulEintIntOp = llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir index 9a1348e67..9a95b96c0 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir @@ -86,6 +86,16 @@ func @single_dyn_sub_int_eint(%e: !HLFHE.eint<2>, %i: i3) -> !HLFHE.eint<2> // ----- +func @single_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2> +{ + // CHECK: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2> + %0 = "HLFHE.neg_eint"(%e) : (!HLFHE.eint<2>) -> !HLFHE.eint<2> + + return %0 : !HLFHE.eint<2> +} + +// ----- + func @single_cst_mul_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %cst = arith.constant 3 : i3 @@ -197,3 +207,18 @@ func @chain_add_eint(%e0: !HLFHE.eint<2>, %e1: !HLFHE.eint<2>, %e2: !HLFHE.eint< return %3 : !HLFHE.eint<2> } + + +// ----- + +func @chain_add_eint_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2> +{ + %cst0 = arith.constant 3 : i3 + + // CHECK: %[[ret:.*]] = "HLFHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + %0 = "HLFHE.add_eint_int"(%e, %cst0) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2> + // CHECK-NEXT: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2> + %1 = "HLFHE.neg_eint"(%0) : (!HLFHE.eint<2>) -> !HLFHE.eint<2> + + return %1 : !HLFHE.eint<2> +} \ No newline at end of file