enhance(compiler): Remove the MANP analysis restriction on HLFHELinalg.dot

While the MANP analysis didn't handle tensors, the dot operation restricted its encrypted operand to come from a block argument. Since tensors are now handled by the MANP pass, this restriction is no longer meaningful: instead, the squared MANP of the encrypted operand (eNorm) is looked up and propagated through the norm computations.
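Concretely, a pattern like the following minimal sketch was previously rejected by an assertion, because the encrypted operand of the dot is an operation result rather than a block argument (this mirrors the `single_cst_dot_after_op` test added below; the function name is illustrative):

func @dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2>
{
  // %0 is an operation result, not a function argument.
  %0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
  %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
  // The MANP analysis can now compute a bound here instead of asserting.
  %1 = "HLFHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
  return %1 : !HLFHE.eint<2>
}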
Quentin Bourgerie
2021-11-25 10:53:17 +01:00
parent c54af9b550
commit bb1add2a6f
3 changed files with 74 additions and 32 deletions

View File

@@ -212,7 +212,8 @@ static std::string APIntToStringValUnsigned(const llvm::APInt &i) {
// Calculates the square of the 2-norm of a tensor initialized with a
// dense matrix of constant, signless integers. Aborts if the value
// type or initialization of `cstOp` is incorrect.
-static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
+static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp,
+                                         llvm::APInt eNorm) {
mlir::DenseIntElementsAttr denseVals =
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
@@ -230,8 +231,9 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
llvm::APInt accu{1, 0, false};
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
-llvm::APInt valSq = APIntWidthExtendUSq(val);
-accu = APIntWidthExtendUAdd(accu, valSq);
+llvm::APInt valSqNorm = APIntWidthExtendUSq(val);
+llvm::APInt mulSqNorm = APIntWidthExtendUMul(valSqNorm, eNorm);
+accu = APIntWidthExtendUAdd(accu, mulSqNorm);
}
return accu;
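A worked instance of the updated accumulation, using the values of the `single_cst_dot_after_op` test added below (constant vector [1, 2, 3, 4], squared MANP of the encrypted operand eNorm = 64):

  accu = sum_i c_i^2 * eNorm
       = (1^2 + 2^2 + 3^2 + 4^2) * 64
       = 30 * 64 = 1920
  MANP = ceil(sqrt(1920)) = 44

Before this change the loop computed only sum_i c_i^2, i.e. it implicitly assumed eNorm = 1.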
@@ -241,7 +243,8 @@ static llvm::APInt denseCstTensorNorm2Sq(mlir::arith::ConstantOp cstOp) {
// integers by conservatively assuming that the dynamic values are the
// maximum for the integer width. Aborts if the tensor type `tTy` is
// incorrect.
-static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
+static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy,
+                                         llvm::APInt eNorm) {
assert(tTy && tTy.getElementType().isSignlessInteger() &&
tTy.hasStaticShape() && tTy.getRank() == 1 &&
"Plaintext operand must be a statically shaped 1D tensor of integers");
@@ -254,6 +257,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
llvm::APInt maxVal = APInt::getMaxValue(elWidth);
llvm::APInt maxValSq = APIntWidthExtendUSq(maxVal);
+llvm::APInt maxMulSqNorm = APIntWidthExtendUMul(maxValSq, eNorm);
// Calculate number of bits for APInt to store number of elements
uint64_t nElts = (uint64_t)tTy.getNumElements();
@@ -262,7 +266,7 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
llvm::APInt nEltsAP{nEltsBits, nElts, false};
-return APIntWidthExtendUMul(maxValSq, nEltsAP);
+return APIntWidthExtendUMul(maxMulSqNorm, nEltsAP);
}
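For the dynamic case, with element width w, n elements, and squared operand MANP eNorm, the returned bound is (2^w - 1)^2 * eNorm * n. For the `single_dyn_dot_after_op` test added below (w = 3, n = 4, eNorm = 64):

  (2^3 - 1)^2 * 64 * 4 = 49 * 64 * 4 = 12544
  MANP = ceil(sqrt(12544)) = 112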
// Calculates the squared Minimal Arithmetic Noise Padding of an
@@ -270,9 +274,12 @@ static llvm::APInt denseDynTensorNorm2Sq(mlir::TensorType tTy) {
static llvm::APInt getSqMANP(
mlir::zamalang::HLFHELinalg::Dot op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
-assert(op->getOpOperand(0).get().isa<mlir::BlockArgument>() &&
-       "Only dot operations with tensors that are function arguments are "
-       "currently supported");
+assert(operandMANPs.size() == 2 &&
+       operandMANPs[0]->getValue().getMANP().hasValue() &&
+       "Missing squared Minimal Arithmetic Noise Padding for encrypted "
+       "operands");
+llvm::APInt eNorm = operandMANPs[0]->getValue().getMANP().getValue();
mlir::arith::ConstantOp cstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
@@ -281,7 +288,7 @@ static llvm::APInt getSqMANP(
if (cstOp) {
// Dot product between a vector of encrypted integers and a vector
// of plaintext constants -> return 2-norm of constant vector
-return denseCstTensorNorm2Sq(cstOp);
+return denseCstTensorNorm2Sq(cstOp, eNorm);
} else {
// Dot product between a vector of encrypted integers and a vector
// of dynamic plaintext values -> conservatively assume that all
@@ -292,7 +299,7 @@ static llvm::APInt getSqMANP(
.getType()
.dyn_cast_or_null<mlir::TensorType>();
-return denseDynTensorNorm2Sq(tTy);
+return denseDynTensorNorm2Sq(tTy, eNorm);
}
}
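Taken together, getSqMANP for the dot now propagates the encrypted operand's squared MANP through both helpers:

  constant plaintext vector c: sum_i c_i^2 * eNorm
  dynamic plaintext vector:    n * (2^w - 1)^2 * eNorm

The old results are recovered for eNorm = 1, which held when the encrypted operand was required to be a block argument (a fresh function argument has squared MANP 1).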

View File

@@ -10,28 +10,6 @@ func @single_zero() -> !HLFHE.eint<2>
// -----
-func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
-{
-%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
-// CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"(%[[op0:.*]], %[[op1:.*]]) {MANP = 6 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
-%0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
-return %0 : !HLFHE.eint<2>
-}
-// -----
-func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2>
-{
-// CHECK: %[[ret:.*]] = "HLFHELinalg.dot_eint_int"([[op0:.*]], %[[op1:.*]]) {MANP = 14 : ui{{[[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
-%0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
-return %0 : !HLFHE.eint<2>
-}
-// -----
func @single_cst_add_eint_int(%e: !HLFHE.eint<2>) -> !HLFHE.eint<2>
{
%cst = arith.constant 3 : i3

View File

@@ -160,6 +160,63 @@ func @apply_multi_lookup_table_after_op(%t: tensor<8x!HLFHE.eint<2>>, %i: tensor
// -----
+/////////////////////////////////////////////////
+// HLFHELinalg.dot_eint_int
+/////////////////////////////////////////////////
+func @single_cst_dot(%t: tensor<4x!HLFHE.eint<2>>) -> !HLFHE.eint<2>
+{
+%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
+// sqrt(1^2*1 + 2^2*1 + 3^2*1 + 4^2*1) = 5.477225575
+// CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"(%[[T:.*]], %[[CST:.*]]) {MANP = 6 : ui{{[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+%0 = "HLFHELinalg.dot_eint_int"(%t, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+return %0 : !HLFHE.eint<2>
+}
+// -----
+func @single_dyn_dot(%t: tensor<4x!HLFHE.eint<2>>, %dyn: tensor<4xi3>) -> !HLFHE.eint<2>
+{
+// sqrt(1*(2^3-1)^2*4) = 14
+// CHECK: %[[V0:.*]] = "HLFHELinalg.dot_eint_int"(%[[T:.*]], %[[DYN:.*]]) {MANP = 14 : ui{{[0-9]+}}} : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+%0 = "HLFHELinalg.dot_eint_int"(%t, %dyn) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+return %0 : !HLFHE.eint<2>
+}
+// -----
+func @single_cst_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2>
+{
+// sqrt((2^3)^2*1) = sqrt(64) = 8
+// CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
+%0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
+%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi3>
+// sqrt(1^2*64 + 2^2*64 + 3^2*64 + 4^2*64) = sqrt(1920) = 43.8178046
+// CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[CST:.*]]) {MANP = 44 : ui{{[0-9]+}}}
+%1 = "HLFHELinalg.dot_eint_int"(%0, %cst) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+return %1 : !HLFHE.eint<2>
+}
+// -----
+func @single_dyn_dot_after_op(%t: tensor<4x!HLFHE.eint<2>>, %i: tensor<4xi3>) -> !HLFHE.eint<2>
+{
+// sqrt((2^3)^2*1) = sqrt(64) = 8
+// CHECK: %[[V0:.*]] = "HLFHELinalg.mul_eint_int"(%[[T:.*]], %[[I:.*]]) {MANP = 8 : ui{{[0-9]+}}}
+%0 = "HLFHELinalg.mul_eint_int"(%t, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> tensor<4x!HLFHE.eint<2>>
+// sqrt(4*(2^3-1)^2*64) = sqrt(12544) = 112
+// CHECK: %[[V1:.*]] = "HLFHELinalg.dot_eint_int"(%[[V0]], %[[I]]) {MANP = 112 : ui{{[0-9]+}}}
+%1 = "HLFHELinalg.dot_eint_int"(%0, %i) : (tensor<4x!HLFHE.eint<2>>, tensor<4xi3>) -> !HLFHE.eint<2>
+return %1 : !HLFHE.eint<2>
+}
+// -----
/////////////////////////////////////////////////
// HLFHELinalg.matmul_eint_int
/////////////////////////////////////////////////