feat: support Conv2d in MANP
@@ -894,6 +894,111 @@ static llvm::APInt getSqMANP(
  return result;
}

static llvm::APInt getSqMANP(
    mlir::concretelang::FHELinalg::Conv2dOp op,
    llvm::ArrayRef<mlir::LatticeElement<MANPLatticeValue> *> operandMANPs) {

  mlir::RankedTensorType weightTy =
      op.weight().getType().cast<mlir::RankedTensorType>();

  mlir::Type weightIntType = weightTy.getElementType();

  // Bias is optional, so there can be either 2 or 3 operands
  assert((operandMANPs.size() == 2 || operandMANPs.size() == 3) &&
         operandMANPs[0]->getValue().getMANP().hasValue() &&
         "Missing squared Minimal Arithmetic Noise Padding for encrypted "
         "operand");

  llvm::APInt inputNorm = operandMANPs[0]->getValue().getMANP().getValue();

  mlir::arith::ConstantOp weightCstOp =
      llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
          op->getOpOperand(1).get().getDefiningOp());
  mlir::DenseIntElementsAttr weightDenseVals =
      weightCstOp
          ? weightCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
          : nullptr;

  mlir::DenseIntElementsAttr biasDenseVals = nullptr;
  mlir::Type biasIntType;
  bool hasBias = operandMANPs.size() == 3;
  if (hasBias) {
    biasIntType =
        op.bias().getType().cast<mlir::RankedTensorType>().getElementType();
    mlir::arith::ConstantOp biasCstOp =
        llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
            op->getOpOperand(2).get().getDefiningOp());
    biasDenseVals =
        biasCstOp
            ? biasCstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value")
            : nullptr;
  }

  // Initialize the accumulator to 0, or to the conservative norm of the
  // bias if there is a non-constant bias
  llvm::APInt accNorm;
  if (hasBias && biasDenseVals == nullptr) {
    accNorm = conservativeIntNorm2Sq(biasIntType);
  } else {
    // 1-bit APInt holding the value 0
    accNorm = llvm::APInt{1, 0, false};
  }

  // Weight shape: Filters * Channels * Height * Width
  uint64_t F = weightTy.getShape()[0];
  uint64_t C = weightTy.getShape()[1];
  uint64_t H = weightTy.getShape()[2];
  uint64_t W = weightTy.getShape()[3];
  if (weightDenseVals) {
    // For a constant weight kernel, use the actual constants to calculate the
    // 2-norm: each input window is multiplied element-wise by the kernel and
    // the products are summed up
    for (uint64_t f = 0; f < F; f++) {
      llvm::APInt tmpNorm = accNorm;
      // If there is a constant bias, start accumulating from its norm
      if (hasBias && biasDenseVals) {
        llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
        tmpNorm = APIntWidthExtendUSq(cst);
      }
      for (uint64_t c = 0; c < C; c++) {
        for (uint64_t h = 0; h < H; h++) {
          for (uint64_t w = 0; w < W; w++) {
            llvm::APInt cst =
                weightDenseVals.getValue<llvm::APInt>({f, c, h, w});
            llvm::APInt weightNorm = APIntWidthExtendUSq(cst);
            llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
            tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
          }
        }
      }
      accNorm = APIntUMax(accNorm, tmpNorm);
    }
  } else {
    // For a dynamic weight operand, conservatively assume that each value is
    // the maximum for the integer width
    llvm::APInt weightNorm = conservativeIntNorm2Sq(weightIntType);
    // For a weight (kernel) of shape tensor<FxCxHxW>, there are C*H*W
    // FHE.mul_eint_int and FHE.add_eint operations for each element of the
    // result
    int64_t n_mul = C * H * W;
    llvm::APInt tmpNorm = llvm::APInt{1, 1, false};
    for (int64_t i = 0; i < n_mul; i++) {
      llvm::APInt mulNorm = APIntWidthExtendUMul(inputNorm, weightNorm);
      tmpNorm = APIntWidthExtendUAdd(mulNorm, tmpNorm);
    }
    if (hasBias && biasDenseVals) {
      // A constant bias contributes its largest squared value across filters
      llvm::APInt maxNorm = tmpNorm;
      for (uint64_t f = 0; f < F; f++) {
        llvm::APInt cst = biasDenseVals.getFlatValue<llvm::APInt>(f);
        llvm::APInt currentNorm = APIntWidthExtendUSq(cst);
        currentNorm = APIntWidthExtendUAdd(currentNorm, tmpNorm);
        maxNorm = APIntUMax(currentNorm, maxNorm);
      }
      tmpNorm = maxNorm;
    }
    accNorm = APIntWidthExtendUAdd(accNorm, tmpNorm);
  }
  return accNorm;
}
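In short, the squared MANP of one conv2d output element is the squared 2-norm of the linear combination it computes, maximized over the F filters: the filter's bias squared, plus the sum over the C*H*W kernel positions of the squared weight times the input's squared MANP. A minimal standalone sketch of this arithmetic with plain 64-bit integers (an illustration only; the pass above works on llvm::APInt with width extension, and conv2dSqManp is a hypothetical name, not part of the patch):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Squared MANP of a conv2d output with constant kernel and bias:
    // max over filters f of (bias[f]^2 + sum_{c,h,w} weight[f][c][h][w]^2 * inSq).
    uint64_t conv2dSqManp(const std::vector<std::vector<int64_t>> &weightPerFilter,
                          const std::vector<int64_t> &bias, uint64_t inSq) {
      uint64_t acc = 0;
      for (size_t f = 0; f < weightPerFilter.size(); ++f) {
        // Start from the squared bias of this filter.
        uint64_t tmp = uint64_t(bias[f] * bias[f]);
        // Each kernel position contributes weight^2 * (input squared MANP).
        for (int64_t w : weightPerFilter[f])
          tmp += uint64_t(w * w) * inSq;
        // The analysis keeps the worst filter.
        acc = std::max(acc, tmp);
      }
      return acc;
    }

For the first test below, conv2dSqManp({{1, 2, 2, 1}}, {5}, 1) gives 35, and ceil(sqrt(35)) = 6 matches the MANP attribute in its CHECK line.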

struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
  using ForwardDataFlowAnalysis<MANPLatticeValue>::ForwardDataFlowAnalysis;
  MANPAnalysis(mlir::MLIRContext *ctx, bool debug)

@@ -973,6 +1078,10 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis<MANPLatticeValue> {
            llvm::dyn_cast<mlir::concretelang::FHELinalg::ConcatOp>(
                op)) {
      norm2SqEquiv = getSqMANP(concatOp, operands);
    } else if (auto conv2dOp =
                   llvm::dyn_cast<mlir::concretelang::FHELinalg::Conv2dOp>(
                       op)) {
      norm2SqEquiv = getSqMANP(conv2dOp, operands);
    }
    // Tensor Operators
    // ExtractOp

@@ -563,3 +563,62 @@ func @concat() -> tensor<3x!FHE.eint<7>> {

  return %9 : tensor<3x!FHE.eint<7>>
}

/////////////////////////////////////////////////
// FHELinalg.conv2d
/////////////////////////////////////////////////

// -----

func @conv2d_const_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<6>>) -> tensor<1x1x2x2x!FHE.eint<6>> {
  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
  %bias = arith.constant dense<[5]> : tensor<1xi7>
  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 6 : ui{{[0-9]+}}
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {
    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
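A quick sanity check of the expected attribute, assuming the printed MANP is the ceiling of the square root of the squared norm computed by the pass: with a fresh input (squared MANP 1), the single filter gives 5^2 + (1^2 + 2^2 + 2^2 + 1^2) * 1 = 35, and ceil(sqrt(35)) = 6.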

// -----

func @conv2d_const_weight(%input: tensor<1x1x4x4x!FHE.eint<6>>, %bias: tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>> {
  %weight = arith.constant dense<[[[[1, 2], [2, 1]]]]> : tensor<1x1x2x2xi7>
  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 129 : ui{{[0-9]+}}
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {
    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
  } : (tensor<1x1x4x4x!FHE.eint<6>>, tensor<1x1x2x2xi7>, tensor<1xi7>) -> tensor<1x1x2x2x!FHE.eint<6>>
  return %0 : tensor<1x1x2x2x!FHE.eint<6>>
}
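Sanity check, additionally assuming conservativeIntNorm2Sq(i7) returns (2^7)^2 = 16384 (this value is consistent with all the expectations here but is not shown in the diff): the dynamic bias starts the accumulator at 16384, the constant kernel adds (1^2 + 2^2 + 2^2 + 1^2) * 1 = 10, and ceil(sqrt(16394)) = 129.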

// -----

func @conv2d_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
  %bias = arith.constant dense<[5]> : tensor<1xi3>
  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 17 : ui{{[0-9]+}}
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {
    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
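Sanity check, under the same assumption for i3 (conservative squared norm (2^3)^2 = 64): the dynamic kernel gives tmpNorm = 1 + 4 * 64 = 257, the constant bias adds 5^2 = 25 for a total of 282, and ceil(sqrt(282)) = 17.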

// -----

func @conv2d_weight_const_bias(%input: tensor<1x1x4x4x!FHE.eint<2>>, %weight: tensor<1x1x2x2xi3>, %bias: tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>> {
  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 18 : ui{{[0-9]+}}
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {
    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
  } : (tensor<1x1x4x4x!FHE.eint<2>>, tensor<1x1x2x2xi3>, tensor<1xi3>) -> tensor<1x1x2x2x!FHE.eint<2>>
  return %0 : tensor<1x1x2x2x!FHE.eint<2>>
}
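Sanity check: both weight and bias are dynamic here, so the accumulator starts at the bias's conservative 64, the kernel contributes tmpNorm = 1 + 4 * 64 = 257, the sum is 321, and ceil(sqrt(321)) = 18.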

// -----

func @conv2d_batched_multiple_channels(%input: tensor<100x3x4x4x!FHE.eint<2>>, %weight: tensor<5x3x2x2xi3>, %bias: tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>> {
  // CHECK: %[[V1:.*]] = "FHELinalg.conv2d"(%[[A0:.*]], %[[A1:.*]], %[[A2:.*]]) {MANP = 29 : ui{{[0-9]+}}
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {
    strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0,0,0]> : tensor<4xi64>
  } : (tensor<100x3x4x4x!FHE.eint<2>>, tensor<5x3x2x2xi3>, tensor<5xi3>) -> tensor<100x5x2x2x!FHE.eint<2>>
  return %0 : tensor<100x5x2x2x!FHE.eint<2>>
}
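Sanity check: with C*H*W = 3*2*2 = 12 dynamic kernel positions, tmpNorm = 1 + 12 * 64 = 769; the dynamic bias adds another 64, giving 833, and ceil(sqrt(833)) = 29. Neither the batch size (100) nor the filter count (5) affects the per-element norm.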