fix: Fix the MANP computation for conv2d (close #883)

Quentin Bourgerie
2023-01-13 15:47:22 +01:00
parent 0329d4fc2d
commit d1ddd60a23
3 changed files with 61 additions and 102 deletions


@@ -1046,29 +1046,8 @@ static llvm::APInt getSqMANP(
? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
: nullptr;
- mlir::DenseIntElementsAttr biasDenseVals = nullptr;
- mlir::Type biasIntType;
- bool hasBias = operandMANPs.size() == 3;
- if (hasBias) {
- biasIntType =
- op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
- mlir::arith::ConstantOp biasCstOp =
- llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
- op->getOpOperand(2).get().getDefiningOp());
- biasDenseVals =
- biasCstOp
- ? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
- : nullptr;
- }
- // Initial value of the accumulator to 0, or the conservative norm of the bias
- // if there is a non-const bias
- llvm::APInt accNorm;
- if (hasBias && biasDenseVals == nullptr) {
- accNorm = conservativeIntNorm2Sq(biasIntType);
- } else {
- accNorm = llvm::APInt{0, 1, false};
- }
+ // Initial value of the accumulator to 0
+ llvm::APInt accNorm = llvm::APInt{1, 0, false};
// Weight shapes: Filter*Channel*Height*Width
uint64_t F = weightTy.getShape()[0];
@@ -1080,12 +1059,8 @@ static llvm::APInt getSqMANP(
// For a constant weight kernel use actual constant to calculate 2-norm
// input windows are being multiplied by a kernel and summed up
for (uint64_t f = 0; f < F; f++) {
- llvm::APInt tmpNorm = accNorm;
- // If there is a bias, start accumulating from its norm
- if (hasBias && biasDenseVals) {
- llvm::APInt cst = biasDenseVals.getValues<llvm::APInt>()[f];
- tmpNorm = APIntWidthExtendSqForConstant(cst);
- }
+ llvm::APInt tmpNorm = inputNorm;
for (uint64_t c = 0; c < C; c++) {
for (uint64_t h = 0; h < H; h++) {
for (uint64_t w = 0; w < W; w++) {
@@ -1096,6 +1071,7 @@ static llvm::APInt getSqMANP(
}
}
}
+ // Take the max of the 2-norm on the filter
accNorm = APIntUMax(accNorm, tmpNorm);
}
} else {
@@ -1106,23 +1082,10 @@ static llvm::APInt getSqMANP(
// FHE.mul_eint_int and FHE.add_eint operations for each elements of the
// result
int64_t n_mul = C * H * W;
- llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
+ llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
for (int64_t i = 0; i < n_mul; i++) {
- llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
- tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
+ accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
}
- if (hasBias && biasDenseVals) {
- auto biasDenseValsAP = biasDenseVals.getValues<llvm::APInt>();
- llvm::APInt maxNorm = tmpNorm;
- for (uint64_t f = 0; f < F; f++) {
- llvm::APInt cst = biasDenseValsAP[f];
- llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst);
- currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
- maxNorm = APIntUMax(currentNorm, maxNorm);
- }
- tmpNorm = maxNorm;
- }
- accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
}
return accNorm;
}
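Net effect of the change: the clear integer bias no longer contributes to the squared MANP. The accumulator now starts at zero and only the C*H*W encrypted-input-by-weight products are summed (per filter, taking the max across filters when the weights are constant). Below is a minimal plain-integer sketch of the dynamic-weight accumulation; the helper name, the (2^3 - 1)^2 = 49 conservative squared norm assumed for i3 weights, and the final ceil(sqrt(...)) step are illustrative assumptions rather than code from this commit.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Models the dynamic-weight branch: each of the n_mul = C * H * W
// multiplications contributes inputSqNorm * weightSqNorm to the accumulator;
// the clear bias addition contributes nothing.
static uint64_t sqManpConv2dDynamicWeight(uint64_t inputSqNorm,
                                          uint64_t weightSqNorm, uint64_t C,
                                          uint64_t H, uint64_t W) {
  uint64_t accNorm = 0; // starts at 0, the bias term is gone
  uint64_t mulNorm = inputSqNorm * weightSqNorm;
  for (uint64_t i = 0; i < C * H * W; i++)
    accNorm += mulNorm;
  return accNorm;
}

int main() {
  // Assumed conservative squared norm of an i3 weight: (2^3 - 1)^2 = 49;
  // squared norm of a fresh encrypted input: 1.
  uint64_t sq1 = sqManpConv2dDynamicWeight(1, 49, 1, 2, 2); // 1x1x2x2 kernel
  uint64_t sq2 = sqManpConv2dDynamicWeight(1, 49, 3, 2, 2); // 5x3x2x2 kernel
  std::printf("MANP = %.0f\n", std::ceil(std::sqrt((double)sq1))); // 14
  std::printf("MANP = %.0f\n", std::ceil(std::sqrt((double)sq2))); // 25
  return 0;
}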


@@ -0,0 +1,54 @@
// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
%bias = arith.constant dense<[5]> : tensor<1xi7>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
// -----
func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
// -----
func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
%bias = arith.constant dense<[5]> : tensor<1xi3>
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
// -----
func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
// -----
func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 25 : ui{{[0-9]+}}
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
return %0 : tensor<100x5x2x2x!FHE.eint<2>>
}
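Taking the encrypted input's squared 2-norm as 1 (a fresh ciphertext) and MANP as the ceiling of the square root of the accumulated squared norm, the new expectations are consistent with the updated computation: the constant kernel [[1, 2], [2, 1]] contributes 1^2 + 2^2 + 2^2 + 1^2 = 10, and ceil(sqrt(10)) = 4 for both constant-weight tests; a dynamic i3 weight contributes a conservative (2^3 - 1)^2 = 49 per multiplication, giving ceil(sqrt(4 * 49)) = 14 for the 1x1x2x2 kernels and ceil(sqrt(12 * 49)) = 25 for the 5x3x2x2 one. The bias operand no longer affects any of the five values.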


@@ -1026,61 +1026,3 @@ func.func @concat() -> tensor<3x!FHE.eint<7>> {
}
- /////////////////////////////////////////////////
- // FHELinalg.conv2d
- /////////////////////////////////////////////////
- // -----
- func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
- %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
- %bias = arith.constant dense<[5]> : tensor<1xi7>
- // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
- %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
- strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
- } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
- return %0 : tensor<1x1x2x2x!FHE.eint<6>>
- }
- // -----
- func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
- %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
- // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 128 : ui{{[0-9]+}}
- %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
- strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
- } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
- return %0 : tensor<1x1x2x2x!FHE.eint<6>>
- }
- // -----
- func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
- %bias = arith.constant dense<[5]> : tensor<1xi3>
- // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 15 : ui{{[0-9]+}}
- %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
- strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
- } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
- return %0 : tensor<1x1x2x2x!FHE.eint<2>>
- }
- // -----
- func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
- // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 16 : ui{{[0-9]+}}
- %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
- strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
- } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
- return %0 : tensor<1x1x2x2x!FHE.eint<2>>
- }
- // -----
- func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
- // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 26 : ui{{[0-9]+}}
- %0 = "FHELinalg.conv2d"(%input, %weight, %bias){
- strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
- } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
- return %0 : tensor<100x5x2x2x!FHE.eint<2>>
- }
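These removed expectations match the old code path, which also folded the bias into the accumulator: roughly, 10 + 5^2 = 35 and ceil(sqrt(35)) = 6 for the constant i7 bias, about 127^2 for the non-constant i7 bias taken at its conservative norm (hence 128), and 196 + 25, 196 + 49 and 588 + 49 giving 15, 16 and 26 for the i3 cases; this reading is reconstructed from the removed bias handling above and the old CHECK values, not stated in the commit.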