mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: Fixing the MANP computation for conv2d (close #883)
This commit is contained in:
@@ -1046,29 +1046,8 @@ static llvm::APInt getSqMANP(
|
||||
? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
|
||||
: nullptr;
|
||||
|
||||
mlir::DenseIntElementsAttr biasDenseVals = nullptr;
|
||||
mlir::Type biasIntType;
|
||||
bool hasBias = operandMANPs.size() == 3;
|
||||
if (hasBias) {
|
||||
biasIntType =
|
||||
op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
|
||||
mlir::arith::ConstantOp biasCstOp =
|
||||
llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
|
||||
op->getOpOperand(2).get().getDefiningOp());
|
||||
biasDenseVals =
|
||||
biasCstOp
|
||||
? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
|
||||
: nullptr;
|
||||
}
|
||||
|
||||
// Initial value of the accumulator to 0, or the conservative norm of the bias
|
||||
// if there is a non-const bias
|
||||
llvm::APInt accNorm;
|
||||
if (hasBias && biasDenseVals == nullptr) {
|
||||
accNorm = conservativeIntNorm2Sq(biasIntType);
|
||||
} else {
|
||||
accNorm = llvm::APInt{0, 1, false};
|
||||
}
|
||||
// Initial value of the accumulator to 0
|
||||
llvm::APInt accNorm = llvm::APInt{1, 0, false};
|
||||
|
||||
// Weight shapes: Filter*Channel*Height*Width
|
||||
uint64_t F = weightTy.getShape()[0];
|
||||
@@ -1080,12 +1059,8 @@ static llvm::APInt getSqMANP(
|
||||
// For a constant weight kernel use actual constant to calculate 2-norm
|
||||
// input windows are being multiplied by a kernel and summed up
|
||||
for (uint64_t f = 0; f < F; f++) {
|
||||
llvm::APInt tmpNorm = accNorm;
|
||||
// If there is a bias, start accumulating from its norm
|
||||
if (hasBias && biasDenseVals) {
|
||||
llvm::APInt cst = biasDenseVals.getValues<llvm::APInt>()[f];
|
||||
tmpNorm = APIntWidthExtendSqForConstant(cst);
|
||||
}
|
||||
llvm::APInt tmpNorm = inputNorm;
|
||||
|
||||
for (uint64_t c = 0; c < C; c++) {
|
||||
for (uint64_t h = 0; h < H; h++) {
|
||||
for (uint64_t w = 0; w < W; w++) {
|
||||
@@ -1096,6 +1071,7 @@ static llvm::APInt getSqMANP(
|
||||
}
|
||||
}
|
||||
}
|
||||
// Take the max of the 2-norm on the filter
|
||||
accNorm = APIntUMax(accNorm, tmpNorm);
|
||||
}
|
||||
} else {
|
||||
@@ -1106,23 +1082,10 @@ static llvm::APInt getSqMANP(
|
||||
// FHE.mul_eint_int and FHE.add_eint operations for each elements of the
|
||||
// result
|
||||
int64_t n_mul = C * H * W;
|
||||
llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
|
||||
for (int64_t i = 0; i < n_mul; i++) {
|
||||
llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
|
||||
tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
|
||||
accNorm = APIntWidthExtendUAdd(mulNorm, accNorm);
|
||||
}
|
||||
if (hasBias && biasDenseVals) {
|
||||
auto biasDenseValsAP = biasDenseVals.getValues<llvm::APInt>();
|
||||
llvm::APInt maxNorm = tmpNorm;
|
||||
for (uint64_t f = 0; f < F; f++) {
|
||||
llvm::APInt cst = biasDenseValsAP[f];
|
||||
llvm::APInt currentNorm = APIntWidthExtendSqForConstant(cst);
|
||||
currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
|
||||
maxNorm = APIntUMax(currentNorm, maxNorm);
|
||||
}
|
||||
tmpNorm = maxNorm;
|
||||
}
|
||||
accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
|
||||
}
|
||||
return accNorm;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
// RUN: concretecompiler --passes canonicalize --passes MANP --passes ConcreteOptimizer --action=dump-fhe --split-input-file %s 2>&1 | FileCheck %s
|
||||
|
||||
func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
%bias = arith.constant dense<[5]> : tensor<1xi7>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 4 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
%bias = arith.constant dense<[5]> : tensor<1xi3>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 14 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 25 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<100x5x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
@@ -1026,61 +1026,3 @@ func.func @concat() -> tensor<3x!FHE.eint<7>> {
|
||||
}
|
||||
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// FHELinalg.conv2d
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
%bias = arith.constant dense<[5]> : tensor<1xi7>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias : tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
|
||||
%weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 128 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<6>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
%bias = arith.constant dense<[5]> : tensor<1xi3>
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 15 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias : tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 16 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<1x1x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias : tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
|
||||
// CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 26 : ui{{[0-9]+}}
|
||||
%0 = "FHELinalg.conv2d"(%input, %weight, %bias){
|
||||
strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
|
||||
} : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
|
||||
return %0 : tensor<100x5x2x2x!FHE.eint<2>>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user