mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): support HLFHE.neg_eint in MANP
This commit is contained in:
@@ -20,6 +20,7 @@ def MANP : FunctionPass<"MANP"> {
|
||||
- HLFHE.add_eint_int
|
||||
- HLFHE.add_eint
|
||||
- HLFHE.sub_int_eint
|
||||
- HLFHE.neg_eint
|
||||
- HLFHE.mul_eint_int
|
||||
- HLFHE.apply_lookup_table
|
||||
|
||||
@@ -47,6 +48,7 @@ def MANP : FunctionPass<"MANP"> {
|
||||
- HLFHE.add_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e, 1], [1, c])
|
||||
- HLFHE.add_eint(e0, e1) -> HLFHELinalg.dot_eint_int([e0, e1], [1, 1])
|
||||
- HLFHE.sub_int_eint(c, e) -> HLFHELinalg.dot_eint_int([e, c], [1, -1])
|
||||
- HLFHE.neg_eint(e) -> HLFHELinalg.dot_eint_int([e], [-1])
|
||||
- HLFHE.mul_eint_int(e, c) -> HLFHELinalg.dot_eint_int([e], [c])
|
||||
|
||||
Dependent dot operations, e.g.,
|
||||
@@ -85,6 +87,7 @@ def MANP : FunctionPass<"MANP"> {
|
||||
- HLFHE.add_eint_int(e, c) -> 1*1*sqN(e) + c*c*1*1 = sqN(e) + c*c
|
||||
- HLFHE.add_eint(e0, e1) -> 1*1*sqN(e0) + 1*1*sqN(e2) = sqN(e1) + sqN(e2)
|
||||
- HLFHE.sub_int_eint(c, e) -> 1*1*sqN(e) + c*c*(-1)*(-1) = sqN(e) + c*c
|
||||
- HLFHE.neg_eint(e) -> (-1)*(-1)*sqN(e) = sqN(e)
|
||||
- HLFHE.mul_eint_int(e, c) -> c*c*sqN(e)
|
||||
|
||||
The final, non-squared 2-norm of an operation is the square root of the
|
||||
|
||||
@@ -393,6 +393,22 @@ static llvm::APInt getSqMANP(
|
||||
return APIntWidthExtendUAdd(sqNorm, eNorm);
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.neg_eint` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
mlir::zamalang::HLFHE::NegEintOp op,
|
||||
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
|
||||
|
||||
assert(
|
||||
operandMANPs.size() == 1 &&
|
||||
operandMANPs[0]->getValue().getMANP().hasValue() &&
|
||||
"Missing squared Minimal Arithmetic Noise Padding for encrypted operand");
|
||||
|
||||
llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
|
||||
|
||||
return eNorm;
|
||||
}
|
||||
|
||||
// Calculates the squared Minimal Arithmetic Noise Padding of a dot operation
|
||||
// that is equivalent to an `HLFHE.mul_eint_int` operation.
|
||||
static llvm::APInt getSqMANP(
|
||||
@@ -663,6 +679,9 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
|
||||
} else if (auto subIntEintOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::SubIntEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(subIntEintOp, operands);
|
||||
} else if (auto negEintOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::NegEintOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(negEintOp, operands);
|
||||
} else if (auto mulEintIntOp =
|
||||
llvm::dyn_cast<mlir::zamalang::HLFHE::MulEintIntOp>(op)) {
|
||||
norm2SqEquiv = getSqMANP(mulEintIntOp, operands);
|
||||
|
||||
@@ -86,6 +86,16 @@ func @single_dyn_sub_int_eint(%e: !HLFHE.eint<2>, %i: i3) -> !HLFHE.eint<2>
|
||||
|
||||
// -----
|
||||
|
||||
func @single_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
{
|
||||
// CHECK: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 1 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHE.neg_eint"(%e) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
|
||||
return %0 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @single_cst_mul_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst = arith.constant 3 : i3
|
||||
@@ -197,3 +207,18 @@ func @chain_add_eint(%e0: !HLFHE.eint<2>, %e1: !HLFHE.eint<2>, %e2: !HLFHE.eint<
|
||||
|
||||
return %3 : !HLFHE.eint<2>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
func @chain_add_eint_neg_eint(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
{
|
||||
%cst0 = arith.constant 3 : i3
|
||||
|
||||
// CHECK: %[[ret:.*]] = "HLFHE.add_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
%0 = "HLFHE.add_eint_int"(%e, %cst0) : (!HLFHE.eint<2>, i3) -> !HLFHE.eint<2>
|
||||
// CHECK-NEXT: %[[ret:.*]] = "HLFHE.neg_eint"(%[[op0:.*]]) {MANP = 4 : ui{{[0-9]+}}} : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
%1 = "HLFHE.neg_eint"(%0) : (!HLFHE.eint<2>) -> !HLFHE.eint<2>
|
||||
|
||||
return %1 : !HLFHE.eint<2>
|
||||
}
|
||||
Reference in New Issue
Block a user