diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index bc35a076a..edaefda90 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -212,7 +212,8 @@ static std::string APIntToStringValUnsigned(const llvm::APInt &i) { // Calculates the square of the 2-norm of a tensor initialized with a // dense matrix of constant, signless integers. Aborts if the value // type or initialization of of `cstOp` is incorrect. -static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) { +static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp, + llvm::APInt eNorm) { mlir::DenseIntElementsAttr denseVals = cstOp->getAttrOfType("value"); @@ -230,8 +231,9 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) { llvm::APInt accu{1, 0, false}; for (llvm::APInt val : denseVals.getValues()) { - llvm::APInt valSq = APIntWidthExtendUSq(val); - accu = APIntWidthExtendUAdd(accu, valSq); + llvm::APInt valSqNorm = APIntWidthExtendUSq(val); + llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm); + accu = APIntWidthExtendUAdd(accu, mulSqNorm); } return accu; @@ -241,7 +243,8 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) { // integers by conservatively assuming that the dynamic values are the // maximum for the integer width. Aborts if the tensor type `tTy` is // incorrect. -static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) { +static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy, + llvm::APInt eNorm) { assert(tTy && tTy.getElementType().isSignlessInteger() && tTy.hasStaticShape() && tTy.getRank() == 1 && "Plaintext operand must be a statically shaped 1D tensor of integers"); @@ -254,6 +257,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) { llvm::APInt maxVal = APInt::getMaxValue(elWidth); llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal); + llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm); // Calculate number of bits for APInt to store number of elements uint64_t nElts = (uint64_t)tTy.getNumElements(); @@ -262,7 +266,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) { llvm::APInt nEltsAP{nEltsBits, nElts, false}; - return APIntWidthExtendUMul(maxValSq, nEltsAP); + return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP); } // Calculates the squared Minimal Arithmetic Noise Padding of an @@ -270,9 +274,12 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) { static llvm::APInt getSqMANP( mlir::zamalang::HLFHELinalg::Dot op, llvm::ArrayRef *> operandMANPs) { - assert(op->getOpOperand(0).get().isa() && - "Only dot operations with tensors that are function arguments are " - "currently supported"); + assert(operandMANPs.size() == 2 && + operandMANPs[0]->getValue().getMANP().hasValue() && + "Missing squared Minimal Arithmetic Noise Padding for encrypted " + "operands"); + + llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue(); mlir::arith::ConstantOp cstOp = llvm::dyn_cast_or_null( @@ -281,7 +288,7 @@ static llvm::APInt getSqMANP( if (cstOp) { // Dot product between a vector of encrypted integers and a vector // of plaintext constants -> return 2-norm of constant vector - return denseCstTensorNorm2Sq(cstOp); + return denseCstTensorNorm2Sq(cstOp, eNorm); } else { // Dot product between a vector of encrypted integers and a vector // of dynamic plaintext values -> conservatively assume that all @@ -292,7 +299,7 @@ static llvm::APInt getSqMANP( .getType() .dyn_cast_or_null(); - return denseDynTensorNorm2Sq(tTy); + return denseDynTensorNorm2Sq(tTy, eNorm); } } diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir index 8ac034817..75a1bebf8 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP.mlir @@ -10,28 +10,6 @@ func @single_zero() -> !HLFHE.eint<2> // ----- -func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2> -{ - %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3> - - // CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> - %0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> - - return %0 : !HLFHE.eint<2> -} - -// ----- - -func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2> -{ - // CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> - %0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> - - return %0 : !HLFHE.eint<2> -} - -// ----- - func @single_cst_add_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2> { %cst = arith.constant 3 : i3 diff --git a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir index ed15c5c9c..27ac6801f 100644 --- a/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir +++ b/compiler/tests/Dialect/HLFHE/Analysis/MANP_linalg.mlir @@ -160,6 +160,63 @@ func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor // ----- +///////////////////////////////////////////////// +// HLFHELinalg.dot_eint_int +///////////////////////////////////////////////// + +func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2> +{ + %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3> + // sqrt(1^2*1 + 2^2*1 + 3^2*1 + 4^2*1) = 5.477225575 + // CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"(%[[T:.*]], %[[CST:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + %0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + return %0 : !HLFHE.eint<2> +} + +// ----- + +func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2> +{ + // sqrt(1*(2^3-1)^2*4) = 14 + // CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"([[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + %0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + + return %0 : !HLFHE.eint<2> +} + +// ----- + +func @single_cst_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2> +{ + // sqrt((2^3)^2*1) = sqrt(64) = 8 + // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} + %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> + + %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3> + // sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046 + // CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[[0-9]+}}} + %1 = "HLFHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + + return %1 : !HLFHE.eint<2> +} + +// ----- + +func @single_dyn_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2> +{ + // sqrt((2^3)^2*1) = sqrt(64) = 8 + // CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"([[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}} + %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>> + + // sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112 + // CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[[0-9]+}}} + %1 = "HLFHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2> + + return %1 : !HLFHE.eint<2> +} + +// ----- + ///////////////////////////////////////////////// // HLFHELinalg.matmul_ent_int /////////////////////////////////////////////////