diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
index 6fd26d894..2bff387d9 100644
--- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
+++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp
@@ -1046,29 +1046,8 @@ static llvm::APInt getSqMANP(
           ? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
           : nullptr;
 
-  mlir::DenseIntElementsAttr biasDenseVals = nullptr;
-  mlir::Type biasIntType;
-  bool hasBias = operandMANPs.size() == 3;
-  if (hasBias) {
-    biasIntType =
-        op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
-    mlir::arith::ConstantOp biasCstOp =
-        llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
-            op->getOpOperand(2).get().getDefiningOp());
-    biasDenseVals =
-        biasCstOp
-            ? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
-            : nullptr;
-  }
-
-  // Initial value of the accumulator to 0, or the conservative norm of the bias
-  // if there is a non-const bias
-  llvm::APInt accNorm;
-  if (hasBias && biasDenseVals == nullptr) {
-    accNorm = conservativeIntNorm2Sq(biasIntType);
-  } else {
-    accNorm = llvm::APInt{0, 1, false};
-  }
+  // Initialize the accumulator to 0
+  llvm::APInt accNorm = llvm::APInt{1, 0, false};
 
   // Weight shapes: Filter*Channel*Height*Width
   uint64_t F = weightTy.getShape()[0];
@@ -1080,12 +1059,8 @@ static llvm::APInt getSqMANP(
     // For a constant weight kernel use actual constant to calculate 2-norm
     // input windows are being multiplied by a kernel and summed up
     for (uint64_t f = 0; f < F; f++) {
-      llvm::APInt tmpNorm = accNorm;
-      // If there is a bias, start accumulating from its norm
-      if (hasBias && biasDenseVals) {
-        llvm::APInt cst = biasDenseVals.getValues<llvm::APInt>()[f];
-        tmpNorm = APIntWidthExtendSqForConstant(cst);
-      }
+      llvm::APInt tmpNorm = inputNorm;
+
       for (uint64_t c = 0; c < C; c++) {
         for (uint64_t h = 0; h < H; h++) {
           for (uint64_t w = 0; w < W; w++) {
@@ -1096,6 +1071,7 @@
           }
         }
       }
+      // Take the max of the squared 2-norm across filters
       accNorm = APIntUMax(accNorm, tmpNorm);
     }
   } else {
@@ -1106,23 +1082,10 @@
     // FHE.mul_eint_int and FHE.add_eint operations for each elements of the
     // result
     int64_t n_mul = C * H * W;
-    llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
+    llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
     for (int64_t i = 0; i < n_mul; i++) {
-      llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
-      tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
+      accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
     }
-    if (hasBias && biasDenseVals) {
-      auto biasDenseValsAP = biasDenseVals.getValues<llvm::APInt>();
-      llvm::APInt maxNorm = tmpNorm;
-      for (uint64_t f = 0; f < F; f++) {
-        llvm::APInt cst = biasDenseValsAP[f];
-        llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst);
-        currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
-        maxNorm = APIntUMax(currentNorm, maxNorm);
-      }
-      tmpNorm = maxNorm;
-    }
-    accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
   }
   return accNorm;
 }
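Two details are worth noting about the hunks above. First, the accumulator is now built as llvm::APInt{1, 0, false}, i.e. a one-bit value of zero; the deleted llvm::APInt{0, 1, false} had APInt's numBits and value arguments swapped. Second, the bias no longer feeds the accumulated squared 2-norm at all: adding a clear bias to an encrypted value does not scale the noise, so only the input norm and the weights matter. As a worked check of the dynamic-weight branch, assuming conservativeIntNorm2Sq on an ip type returns (2^p - 1)^2 (consistent with the old i7 check value of 128 removed further down):

    sqNorm = n_mul * inputNorm * weightNorm
           = (C * H * W) * 1 * (2^3 - 1)^2
           = 4 * 49 = 196,    MANP = ceil(sqrt(196)) = 14

which matches @conv2d_const_bias in the new test file below; the batched case gives (3 * 2 * 2) * 49 = 588 and ceil(sqrt(588)) = 25.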
diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir
new file mode 100644
index 000000000..ba34a0f48
--- /dev/null
+++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_conv2d.mlir
@@ -0,0 +1,54 @@
+// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
+
+func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
+  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
+  %bias = arith.constant dense<[5]> : tensor<1xi7>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
+}
+
+// -----
+
+func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
+  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
+}
+
+// -----
+
+func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
+  %bias = arith.constant dense<[5]> : tensor<1xi3>
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
+}
+
+// -----
+
+func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
+  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
+}
+
+// -----
+
+func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
+  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 25 : ui{{[0-9]+}}
+  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
+    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
+  } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
+  return %0 : tensor<100x5x2x2x!FHE.eint<2>>
+}
+
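For the constant-weight tests above, the per-filter accumulation now starts from the input norm rather than the bias norm. Assuming the inner-loop body elided by the hunk context adds inputNorm * w^2 per kernel coefficient, the first two tests give

    sqNorm = 1 + (1^2 + 2^2 + 2^2 + 1^2) * 1 = 11,    MANP = ceil(sqrt(11)) = 4

down from 6 (constant bias of 5: 25 + 10 = 35) and 128 (conservative i7 bias: 127^2 + 10 = 16139) in the old checks removed below.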
diff --git a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir
index 1bb848b1c..4d59ff786 100644
--- a/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir
+++ b/compiler/tests/check_tests/Dialect/FHE/Analysis/MANP_linalg.mlir
@@ -1026,61 +1026,3 @@ func.func @concat() -> tensor<3x!FHE.eint<7>> {
 }
 
 
-/////////////////////////////////////////////////
-// FHELinalg.conv2d
-/////////////////////////////////////////////////
-
-// -----
-
-func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
-  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
-  %bias = arith.constant dense<[5]> : tensor<1xi7>
-  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
-  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
-    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
-  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
-  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
-}
-
-// -----
-
-func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
-  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
-  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 128 : ui{{[0-9]+}}
-  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
-    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
-  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
-  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
-}
-
-// -----
-
-func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
-  %bias = arith.constant dense<[5]> : tensor<1xi3>
-  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 15 : ui{{[0-9]+}}
-  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
-    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
-  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
-  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
-}
-
-// -----
-
-func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
-  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 16 : ui{{[0-9]+}}
-  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
-    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
-  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
-  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
-}
-
-// -----
-
-func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
-  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 26 : ui{{[0-9]+}}
-  %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
-    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
-  } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
-  return %0 : tensor<100x5x2x2x!FHE.eint<2>>
-}
-
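Putting the two branches together, here is a minimal plain-integer sketch of the simplified rule, using uint64_t in place of the width-extending llvm::APInt helpers and assuming the constant-weight inner-loop body elided by the hunk context; conv2dSqMANP is a hypothetical name for illustration, not the pass's API:

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of the simplified conv2d squared-MANP rule: the bias never
// contributes. `weights` holds the F flattened constant filters, or is
// empty when the kernel is dynamic; `p` is the weight element bit-width
// (e.g. 3 for i3), giving the conservative squared norm (2^p - 1)^2.
uint64_t conv2dSqMANP(uint64_t inputSqNorm, unsigned p,
                      const std::vector<std::vector<int64_t>> &weights,
                      uint64_t C, uint64_t H, uint64_t W) {
  uint64_t accNorm = 0;
  if (!weights.empty()) {
    // Constant kernel: accumulate one window per filter, starting from
    // the input norm, and keep the max across filters.
    for (const std::vector<int64_t> &filter : weights) {
      uint64_t tmpNorm = inputSqNorm;
      for (int64_t w : filter)
        tmpNorm += inputSqNorm * uint64_t(w * w);
      accNorm = std::max(accNorm, tmpNorm);
    }
  } else {
    // Dynamic kernel: one conservative multiply-accumulate per kernel
    // coefficient contributing to a single output element.
    uint64_t weightSqNorm = (1ull << p) - 1;
    weightSqNorm *= weightSqNorm;
    accNorm = C * H * W * inputSqNorm * weightSqNorm;
  }
  return accNorm; // the MANP attribute is ceil(sqrt(accNorm))
}

For example, conv2dSqMANP(1, 3, {}, 3, 2, 2) yields 588, i.e. MANP 25 as in @conv2d_batched_multiple_channels, and with the single constant filter {1, 2, 2, 1} it yields 11, i.e. MANP 4 as in the two constant-weight tests.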