feat: support Conv2d in MANP

This commit is contained in:
youben11
2022-02-16 14:46:09 +01:00
committed by Ayoub Benaissa
parent 3668b2d73a
commit 6d2f853c07
2 changed files with 168 additions and 0 deletions

View File

@@ -894,6 +894,111 @@ static llvm::APInt getSqMANP(
return result;
}
static llvm::APInt getSqMANP(
mlir::concretelang::FHELinalg::Conv2dOp op,
llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
mlir::RankedTensorType weightTy =
op.weight().getType().cast<mlir::RankedTensorType>();
mlir::Type weightIntType = weightTy.getElementType();
// Bias is optional, so we can have both 2 or 3 operands
assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) &&
operandMANPs[0]->getValue().getMANP().hasValue() &&
"Missing squared Minimal Arithmetic Noise Padding for encrypted "
"operand");
llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().getValue();
mlir::arith::ConstantOp weightCstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
op->getOpOperand(1).get().getDefiningOp());
mlir::DenseIntElementsAttr weightDenseVals =
weightCstOp
? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
mlir::DenseIntElementsAttr biasDenseVals = nullptr;
mlir::Type biasIntType;
bool hasBias = operandMANPs.size() == 3;
if (hasBias) {
biasIntType =
op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
mlir::arith::ConstantOp biasCstOp =
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
op->getOpOperand(2).get().getDefiningOp());
biasDenseVals =
biasCstOp
? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
}
// Initial value of the accumulator to 0, or the conservative norm of the bias
// if there is a non-const bias
llvm::APInt accNorm;
if (hasBias && biasDenseVals == nullptr) {
accNorm = conservativeIntNorm2Sq(biasIntType);
} else {
accNorm = llvm::APInt{0, 1, false};
}
// Weight shapes: Filter*Channel*Height*Width
uint64_t F = weightTy.getShape()[0];
uint64_t C = weightTy.getShape()[1];
uint64_t H = weightTy.getShape()[2];
uint64_t W = weightTy.getShape()[3];
if (weightDenseVals) {
// For a constant weight kernel use actual constant to calculate 2-norm
// input windows are being multiplied by a kernel and summed up
for (uint64_t f = 0; f < F; f++) {
llvm::APInt tmpNorm = accNorm;
// If there is a bias, start accumulating from its norm
if (hasBias && biasDenseVals) {
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
tmpNorm = APIntWidthExtendUSq(cst);
}
for (uint64_t c = 0; c < C; c++) {
for (uint64_t h = 0; h < H; h++) {
for (uint64_t w = 0; w < W; w++) {
llvm::APInt cst =
weightDenseVals.getValue<llvm::APInt>({f, c, h, w});
llvm::APInt weightNorm = APIntWidthExtendUSq(cst);
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
}
}
}
accNorm = APIntUMax(accNorm, tmpNorm);
}
} else {
// For a dynamic operand conservatively assume that the value is
// the maximum for the integer width
llvm::APInt weightNorm = conservativeIntNorm2Sq(weightIntType);
// For a weight (kernel) of shape tensor<FxCxHxW>, there is C*H*W
// FHE.mul_eint_int and FHE.add_eint operations for each elements of the
// result
int64_t n_mul = C * H * W;
llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
for (int64_t i = 0; i < n_mul; i++) {
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
}
if (hasBias && biasDenseVals) {
llvm::APInt maxNorm = tmpNorm;
for (uint64_t f = 0; f < F; f++) {
llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
llvm::APInt currentNorm = APIntWidthExtendUSq(cst);
currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
maxNorm = APIntUMax(currentNorm, maxNorm);
}
tmpNorm = maxNorm;
}
accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
}
return accNorm;
}
struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -973,6 +1078,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
op)) {
norm2SqEquiv = getSqMANP(concatOp, operands);
} else if (auto conv2dOp =
llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
op)) {
norm2SqEquiv = getSqMANP(conv2dOp, operands);
}
// Tensor Operators
// ExtractOp

View File

@@ -563,3 +563,62 @@ func @concat() -> tensor<3x!FHE.eint<7>> {
return %9 : tensor<3x!FHE.eint<7>>
}
/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////
// -----
func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
%bias = arith.constant dense<[5]> : tensor<1xi7>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
// -----
func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
// -----
func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
%bias = arith.constant dense<[5]> : tensor<1xi3>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
// -----
func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
// -----
func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
return %0 : tensor<100x5x2x2x!FHE.eint<2>>
}