feat: support grouped conv2d

Author: youben11
Date: 2022-06-28 16:54:01 +01:00
Committed by: Ayoub Benaissa
Parent: 63d84a3e4a
Commit: e52ccfc1a9
6 changed files with 220 additions and 30 deletions
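Grouped convolution splits the input channels into `group` independent convolutions whose outputs are concatenated along the channel dimension. With this change, `FHELinalg.conv2d` accepts an optional `group` attribute (defaulting to 1 when absent). An illustrative valid instance, with hypothetical shapes chosen so that the weight channel count equals input_channels / group and the number of feature maps is a multiple of group:

func.func @grouped_conv2d(%input: tensor<100x6x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
  // 6 input channels with group = 2: each convolution sees 3 channels and
  // produces 2 of the 4 output feature maps
  %0 = "FHELinalg.conv2d"(%input, %weight, %bias) {group = 2 : i64} : (tensor<100x6x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
  return %0 : tensor<100x4x15x15x!FHE.eint<2>>
}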

View File

@@ -117,6 +117,9 @@ getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
mlir::SmallVector<int64_t, 2>
getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
/// Get the group from the Conv2dOp if defined, or return the default value
int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp);
} // namespace FHELinalg
} // namespace concretelang
} // namespace mlir

View File

@@ -937,7 +937,8 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> {
// Since there is no U64ElementsAttr, we use I64 and make sure there is no neg values during verification
OptionalAttr<I64ElementsAttr>:$padding,
OptionalAttr<I64ElementsAttr>:$strides,
OptionalAttr<I64ElementsAttr>:$dilations,
OptionalAttr<I64Attr>:$group
);
let results = (outs Type<And<[TensorOf<[FHE_EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
let hasVerifier = 1;

View File

@@ -1507,10 +1507,140 @@ getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
return paddedInput;
}
mlir::Value extractContiguous4DSlice(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Value input,
mlir::RankedTensorType resultType,
llvm::SmallVector<int64_t, 4> sizes,
llvm::SmallVector<int64_t, 4> offsets) {
return rewriter
.create<mlir::tensor::ExtractSliceOp>(
loc, resultType, input,
// offset
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(offsets[0]),
rewriter.getI64IntegerAttr(offsets[1]),
rewriter.getI64IntegerAttr(offsets[2]),
rewriter.getI64IntegerAttr(offsets[3]),
},
// sizes
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(sizes[0]),
rewriter.getI64IntegerAttr(sizes[1]),
rewriter.getI64IntegerAttr(sizes[2]),
rewriter.getI64IntegerAttr(sizes[3]),
},
// strides
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
})
.getResult();
}
/// Create operations for grouped convolution. This will slice the input,
/// weight, and output tensors to apply separate conv2d operations.
mlir::LogicalResult
createGroupedConv2D(mlir::PatternRewriter &rewriter,
mlir::concretelang::FHELinalg::Conv2dOp &conv2dOp,
mlir::Value paddedInput, mlir::Value weight,
mlir::Value outputTensor,
mlir::DenseIntElementsAttr stridesAttr,
mlir::DenseIntElementsAttr dilationsAttr, int64_t group) {
mlir::RankedTensorType inputTy =
paddedInput.getType().cast<mlir::RankedTensorType>();
mlir::Type inputElemTy = inputTy.getElementType();
llvm::ArrayRef<int64_t> inputShape = inputTy.getShape();
llvm::SmallVector<int64_t, 4> inputSliceSizes(
{inputShape[0], inputShape[1] / group, inputShape[2], inputShape[3]});
mlir::RankedTensorType weightTy =
weight.getType().cast<mlir::RankedTensorType>();
mlir::Type weightElemTy = weightTy.getElementType();
llvm::ArrayRef<int64_t> weightShape = weightTy.getShape();
llvm::SmallVector<int64_t, 4> weightSliceSizes(
{weightShape[0] / group, weightShape[1], weightShape[2], weightShape[3]});
mlir::RankedTensorType resultTy =
conv2dOp.getResult().getType().cast<mlir::RankedTensorType>();
llvm::ArrayRef<int64_t> resultShape = resultTy.getShape();
llvm::SmallVector<int64_t, 4> sliceResultSizes = {
resultShape[0], weightSliceSizes[0], resultShape[2], resultShape[3]};
mlir::RankedTensorType sliceResultType =
mlir::RankedTensorType::get(sliceResultSizes, inputElemTy);
// slice the input, weight, and output to apply a separate convolution per
// group, and collect the per-group outputs in `finalResult`
mlir::Value finalResult = outputTensor;
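// For example (hypothetical shapes): with an input of tensor<100x6x28x28>,
// a weight of tensor<4x3x14x14> and group = 2, iteration g extracts an input
// slice of sizes {100, 3, 28, 28} at offsets {0, 3 * g, 0, 0} and a weight
// slice of sizes {2, 3, 14, 14} at offsets {2 * g, 0, 0, 0}, then inserts
// the resulting tensor<100x2x15x15> at channel offset 2 * g.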
for (int g = 0; g < group; g++) {
// input[:][g * (input_C / group) : (g + 1) * (input_C / group)][:][:]
mlir::Value inputSlice = extractContiguous4DSlice(
rewriter, conv2dOp.getLoc(), paddedInput,
mlir::RankedTensorType::get(inputSliceSizes, inputElemTy),
inputSliceSizes, {0, g * inputSliceSizes[1], 0, 0});
// weight[g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:][:]
mlir::Value weightSlice = extractContiguous4DSlice(
rewriter, conv2dOp.getLoc(), weight,
mlir::RankedTensorType::get(weightSliceSizes, weightElemTy),
weightSliceSizes, {g * weightSliceSizes[0], 0, 0, 0});
// bias[:][g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:]
mlir::Value biasSlice = extractContiguous4DSlice(
rewriter, conv2dOp.getLoc(), outputTensor, sliceResultType,
sliceResultSizes, {0, g * sliceResultSizes[1], 0, 0});
// attributes for custom linalg named op
auto addOpAttr = rewriter.getNamedAttr(
"add", rewriter.getStringAttr(
mlir::concretelang::FHE::AddEintOp::getOperationName()));
auto mulOpAttr = rewriter.getNamedAttr(
"mul", rewriter.getStringAttr(
mlir::concretelang::FHE::MulEintIntOp::getOperationName()));
// apply conv
mlir::Value convResult =
rewriter
.create<mlir::linalg::Conv2DNchwFchwOp>(
conv2dOp.getLoc(), sliceResultType,
mlir::ValueRange{inputSlice, weightSlice}, biasSlice,
stridesAttr, dilationsAttr,
llvm::ArrayRef<mlir::NamedAttribute>({addOpAttr, mulOpAttr}))
.getResult(0);
// insert result of a single conv in the final result
finalResult =
rewriter
.create<mlir::tensor::InsertSliceOp>(
conv2dOp.getLoc(), convResult, finalResult,
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(g * sliceResultSizes[1]),
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0),
},
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(sliceResultSizes[0]),
rewriter.getI64IntegerAttr(sliceResultSizes[1]),
rewriter.getI64IntegerAttr(sliceResultSizes[2]),
rewriter.getI64IntegerAttr(sliceResultSizes[3]),
},
llvm::SmallVector<mlir::OpFoldResult, 4>{
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
rewriter.getI64IntegerAttr(1),
})
.getResult();
}
rewriter.replaceOp(conv2dOp, finalResult);
return mlir::success();
}
/// This rewrite pattern transforms any instance of operators
/// `FHELinalg.conv2d` to one or multiple instances of
/// `linalg.conv_2d_nchw_fchw`. The transformation consists of padding the input
/// tensor, and initializing the output tensor with bias values if any. Multiple
/// linalg conv operations can be generated, and their outputs concatenated, in
/// the case of grouped convolution.
struct FHELinalgConv2dToLinalgConv2d
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Conv2dOp> {
FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)
@@ -1537,6 +1667,7 @@ struct FHELinalgConv2dToLinalgConv2d
mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp);
mlir::SmallVector<int64_t, 2> dilationsInts =
mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp);
int64_t group = mlir::concretelang::FHELinalg::getGroupFromConv2d(conv2dOp);
// Pad the input tensor according to padding.
mlir::SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
@@ -1602,18 +1733,29 @@ struct FHELinalgConv2dToLinalgConv2d
auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
// we can directly use linalg::Conv2DNchwFchwOp if group is equal to 1, but
// since there is no support for groups in linalg conv operations, we need
// to slice the different tensors and apply multiple convolutions when group
// is greater than 1
if (group == 1) {
auto addOpAttr = rewriter.getNamedAttr(
"add", rewriter.getStringAttr(
mlir::concretelang::FHE::AddEintOp::getOperationName()));
auto mulOpAttr = rewriter.getNamedAttr(
"mul",
rewriter.getStringAttr(
mlir::concretelang::FHE::MulEintIntOp::getOperationName()));
rewriter.replaceOpWithNewOp<mlir::linalg::Conv2DNchwFchwOp>(
conv2dOp, biasInitTensor.getType(),
mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
dilationsAttr,
llvm::ArrayRef<mlir::NamedAttribute>({addOpAttr, mulOpAttr}));
return mlir::success();
}
return createGroupedConv2D(rewriter, conv2dOp, paddedInput, weight,
biasInitTensor, stridesAttr, dilationsAttr,
group);
};
};
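For group > 1 the pattern therefore emits, for each group, an extract_slice of the padded input, the weight, and the bias-initialized output, one convolution, and an insert_slice into the running result. A rough sketch of the IR this produces for group = 2 (hypothetical shapes; the add/mul string attributes naming the FHE ops, attached as in the code above, are omitted):

%in0 = tensor.extract_slice %padded[0, 0, 0, 0] [100, 3, 28, 28] [1, 1, 1, 1] : tensor<100x6x28x28x!FHE.eint<2>> to tensor<100x3x28x28x!FHE.eint<2>>
%w0 = tensor.extract_slice %weight[0, 0, 0, 0] [2, 3, 14, 14] [1, 1, 1, 1] : tensor<4x3x14x14xi3> to tensor<2x3x14x14xi3>
%b0 = tensor.extract_slice %init[0, 0, 0, 0] [100, 2, 15, 15] [1, 1, 1, 1] : tensor<100x4x15x15x!FHE.eint<2>> to tensor<100x2x15x15x!FHE.eint<2>>
%c0 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%in0, %w0 : tensor<100x3x28x28x!FHE.eint<2>>, tensor<2x3x14x14xi3>) outs(%b0 : tensor<100x2x15x15x!FHE.eint<2>>) -> tensor<100x2x15x15x!FHE.eint<2>>
%r0 = tensor.insert_slice %c0 into %init[0, 0, 0, 0] [100, 2, 15, 15] [1, 1, 1, 1] : tensor<100x2x15x15x!FHE.eint<2>> into tensor<100x4x15x15x!FHE.eint<2>>
// ... same for g = 1, with input channel offset 3 and weight/output channel
// offset 2, inserting into %r0 to form the final result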

View File

@@ -808,6 +808,13 @@ getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
return dilationsInts;
}
int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) {
llvm::Optional<uint64_t> optionalGroup = convOp.group();
if (optionalGroup.hasValue())
return optionalGroup.getValue();
return 1;
}
/// Verify the Conv2d shapes, attributes, and expected output dimensions
mlir::LogicalResult Conv2dOp::verify() {
auto inputTy =
@@ -920,6 +927,11 @@ mlir::LogicalResult Conv2dOp::verify() {
}
}
}
int64_t group = getGroupFromConv2d(*this);
if (group < 1) {
this->emitOpError() << "group must be strictly positif, but got " << group;
return mlir::failure();
}
// Extracting dimensions
int64_t inputN = inputShape[0], inputC = inputShape[1],
@@ -960,10 +972,15 @@ mlir::LogicalResult Conv2dOp::verify() {
<< inputN << ") but got " << resultN;
return mlir::failure();
}
if (weightC != inputC / group) {
this->emitOpError()
<< "expected number of channels in weight to be equal to "
<< inputC / group << " (input_channels / group) but got " << weightC;
return mlir::failure();
}
if (weightF % group != 0) {
this->emitOpError() << "expected number of feature maps (" << weightF
<< ") to be a multiple of group (" << group << ")";
return mlir::failure();
}
if (weightF != resultC) {

View File

@@ -213,6 +213,41 @@ func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x
// -----
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op group must be strictly positive, but got 0}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 0 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to 1 (input_channels / group) but got 3}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x1x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected number of feature maps (4) to be a multiple of group (3)}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x1x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<6x1x14x14xi3>, %bias: tensor<6xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected number of output channels to be equal to the number of filters (6) but got 4}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<6x1x14x14xi3>, tensor<6xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----
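// Note (illustrative, not part of this commit's test suite): a grouped conv2d
// that satisfies all of the checks above; 4 input channels with group = 4
// give weight channels 4 / 4 = 1, and the 4 feature maps are a multiple of
// the group.
func.func @conv2d_valid_group(%input: tensor<100x4x28x28x!FHE.eint<2>>, %weight: tensor<4x1x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 4 : i64}: (tensor<100x4x28x28x!FHE.eint<2>>, tensor<4x1x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
// -----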
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 8 but got 15}}
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
@@ -231,14 +266,6 @@ func.func @conv2d(%input: tensor<101x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x
// -----
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi4>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// expected-error @+1 {{'FHELinalg.conv2d' op expected weight element type to have width 3 but got 4}}

View File

@@ -435,11 +435,11 @@ func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>,
/////////////////////////////////////////////////
// CHECK: func.func @conv2d(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>>
// CHECK-NEXT: }
func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>, group = 1 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}