From 6d2f853c079f2901d24dc74ea84ad8c74de70830 Mon Sep 17 00:00:00 2001
From: youben11
Date: Wed, 16 Feb 2022 14:46:09 +0100
Subject: [PATCH] feat: support Conv2d in MANP
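
Adds a getSqMANP overload for FHELinalg.Conv2dOp and wires it into
MANPAnalysis. For a constant weight kernel, the squared norm of an output
element is accumulated per filter f as bias[f]^2 plus the sum over the
C*H*W kernel positions of inputNorm * weight^2, keeping the maximum over
filters; a dynamic weight or bias falls back to the conservative squared
norm of its integer width. Tests cover every const/dynamic combination of
weight and bias.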
---
 compiler/lib/Dialect/FHE/Analysis/MANP.cpp    | 109 ++++++++++++++++++
 .../Dialect/FHE/FHE/Analysis/MANP_linalg.mlir |  59 ++++++++++
 2 files changed, 168 insertions(+)

diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index 977c4ed83..6da164264 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -894,6 +894,111 @@ static llvm::APInt getSqMANP(
   return result;
 }
 
+static llvm::APInt getSqMANP(
+    mlir::concretelang::FHELinalg::Conv2dOp op,
+    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {
+
+  mlir::RankedTensorType weightTy =
+      op.weight().getType().cast<mlir::RankedTensorType>();
+
+  mlir::Type weightIntType = weightTy.getElementType();
+
+  // Bias is optional, so there can be either 2 or 3 operands
+  assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) &&
+         operandMANPs[0]->getValue().getMANP().hasValue() &&
+         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
+         "operand");
+
+  llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().getValue();
+
+  mlir::arith::ConstantOp weightCstOp =
+      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
+          op->getOpOperand(1).get().getDefiningOp());
+  mlir::DenseIntElementsAttr weightDenseVals =
+      weightCstOp
+          ? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+          : nullptr;
+
+  mlir::DenseIntElementsAttr biasDenseVals = nullptr;
+  mlir::Type biasIntType;
+  bool hasBias = operandMANPs.size() == 3;
+  if (hasBias) {
+    biasIntType =
+        op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
+    mlir::arith::ConstantOp biasCstOp =
+        llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
+            op->getOpOperand(2).get().getDefiningOp());
+    biasDenseVals =
+        biasCstOp
+            ? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
+            : nullptr;
+  }
+
+  // Initialize the accumulator to 0, or to the conservative norm of the
+  // bias if there is a non-constant bias
+  llvm::APInt accNorm;
+  if (hasBias && biasDenseVals == nullptr) {
+    accNorm = conservativeIntNorm2Sq(biasIntType);
+  } else {
+    accNorm = llvm::APInt{1, 0, false};
+  }
+
+  // Weight shape: Filter x Channel x Height x Width
+  uint64_t F = weightTy.getShape()[0];
+  uint64_t C = weightTy.getShape()[1];
+  uint64_t H = weightTy.getShape()[2];
+  uint64_t W = weightTy.getShape()[3];
+  if (weightDenseVals) {
+    // For a constant weight kernel, use the actual constants to compute the
+    // 2-norm: input windows are multiplied by the kernel and summed up
+    for (uint64_t f = 0; f < F; f++) {
+      llvm::APInt tmpNorm = accNorm;
+      // If there is a constant bias, start accumulating from its norm
+      if (hasBias && biasDenseVals) {
+        llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
+        tmpNorm = APIntWidthExtendUSq(cst);
+      }
+      for (uint64_t c = 0; c < C; c++) {
+        for (uint64_t h = 0; h < H; h++) {
+          for (uint64_t w = 0; w < W; w++) {
+            llvm::APInt cst =
+                weightDenseVals.getValue<llvm::APInt>({f, c, h, w});
+            llvm::APInt weightNorm = APIntWidthExtendUSq(cst);
+            llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
+            tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
+          }
+        }
+      }
+      accNorm = APIntUMax(accNorm, tmpNorm);
+    }
+  } else {
+    // For a dynamic weight operand, conservatively assume that each value
+    // is the maximum for the integer width
+    llvm::APInt weightNorm = conservativeIntNorm2Sq(weightIntType);
+    // For a weight (kernel) of shape FxCxHxW, there are C*H*W
+    // FHE.mul_eint_int and FHE.add_eint operations for each element of the
+    // result
+    int64_t n_mul = C * H * W;
+    llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
+    for (int64_t i = 0; i < n_mul; i++) {
+      llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
+      tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
+    }
+    if (hasBias && biasDenseVals) {
+      llvm::APInt maxNorm = tmpNorm;
+      for (uint64_t f = 0; f < F; f++) {
+        llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
+        llvm::APInt currentNorm = APIntWidthExtendUSq(cst);
+        currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
+        maxNorm = APIntUMax(currentNorm, maxNorm);
+      }
+      tmpNorm = maxNorm;
+    }
+    accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
+  }
+  return accNorm;
+}
+
 struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
   using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis;
   MANPAnalysis(mlir::MLIRContext *ctx, bool debug)
@@ -973,6 +1078,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
                    llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
                        op)) {
       norm2SqEquiv = getSqMANP(concatOp, operands);
+    } else if (auto conv2dOp =
+                   llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
+                       op)) {
+      norm2SqEquiv = getSqMANP(conv2dOp, operands);
     }
     // Tensor Operators
     // ExtractOp
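
A note on the expected MANP values in the tests below (hand calculations
mirroring the code above; based on the existing MANP tests it is assumed
that the printed attribute is ceil(sqrt(sqMANP)) and that
conservativeIntNorm2Sq(iN) equals (2^N)^2):

  conv2d_const_weight_const_bias: sqMANP = 5^2 + 1*(1^2+2^2+2^2+1^2) = 35,
  and ceil(sqrt(35)) = 6.

  conv2d_batched_multiple_channels (dynamic i3 weight and bias, C*H*W =
  3*2*2 = 12 multiplications per output element): sqMANP = (2^3)^2 + 1 +
  12*(2^3)^2 = 833, and ceil(sqrt(833)) = 29.
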
diff --git a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
index d15eb2fb2..b862fce8f 100644
--- a/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
+++ b/compiler/tests/Dialect/FHE/FHE/Analysis/MANP_linalg.mlir
@@ -563,3 +563,62 @@ func @concat() -> tensor<3x!FHE.eint<7>> {
 
   return %9 : tensor<3x!FHE.eint<7>>
 }
+
+
+/////////////////////////////////////////////////
+// FHELinalg.conv2d
+/////////////////////////////////////////////////
+
+// -----
+
+func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
+  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
+  %bias = arith.constant dense<[5]> : tensor<1xi7>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
+}
+
+// -----
+
+func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
+  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
+}
+
+// -----
+
+func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
+  %bias = arith.constant dense<[5]> : tensor<1xi3>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d_dynamic_weight_dynamic_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
+}
+
+// -----
+
+func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
+  return %0 : tensor<100x5x2x2x!FHE.eint<2>>
+}
\ No newline at end of file
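
For review, a standalone sanity check of the remaining expected values
(plain C++, no MLIR; conservativeSq is a stand-in for the compiler's
conservativeIntNorm2Sq, inferred from the test expectations rather than
taken from the source):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Conservative squared 2-norm assumed for a dynamic integer operand of
// width w, i.e. (2^w)^2 (inferred from the dynamic-operand tests above).
static uint64_t conservativeSq(unsigned w) {
  uint64_t m = 1ull << w;
  return m * m;
}

// Squared MANP for a dynamic weight kernel: 1 plus C*H*W conservative
// multiplications, plus the bias contribution (conservative when the bias
// is dynamic, max over bias[f]^2 when it is constant).
static uint64_t sqManpDynWeight(uint64_t inputSq, unsigned weightWidth,
                                uint64_t chw,
                                const std::vector<int64_t> *constBias,
                                unsigned biasWidth) {
  uint64_t tmp = 1 + chw * inputSq * conservativeSq(weightWidth);
  if (constBias == nullptr) // dynamic bias: add its conservative norm
    return conservativeSq(biasWidth) + tmp;
  uint64_t maxNorm = tmp;
  for (int64_t b : *constBias)
    maxNorm = std::max(maxNorm, (uint64_t)(b * b) + tmp);
  return maxNorm;
}

// The MANP attribute printed by the analysis, assumed to be the rounded-up
// square root of the squared norm.
static uint64_t manp(uint64_t sq) {
  return (uint64_t)std::ceil(std::sqrt((double)sq));
}

int main() {
  // conv2d_const_weight: constant weights {1,2,2,1}, dynamic i7 bias.
  assert(manp(conservativeSq(7) + (1 + 4 + 4 + 1)) == 129);
  // conv2d_const_bias: dynamic i3 weight, C*H*W = 4, constant bias {5}.
  std::vector<int64_t> bias{5};
  assert(manp(sqManpDynWeight(1, 3, 4, &bias, 3)) == 17);
  // conv2d_dynamic_weight_dynamic_bias: both dynamic, i3, C*H*W = 4.
  assert(manp(sqManpDynWeight(1, 3, 4, nullptr, 3)) == 18);
  return 0;
}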