From e52ccfc1a904a8711b1473ec96cca1f1ea3b92e9 Mon Sep 17 00:00:00 2001 From: youben11 Date: Tue, 28 Jun 2022 16:54:01 +0100 Subject: [PATCH] feat: support grouped conv2d --- .../Dialect/FHELinalg/IR/FHELinalgOps.h | 3 + .../Dialect/FHELinalg/IR/FHELinalgOps.td | 3 +- .../TensorOpsToLinalg.cpp | 172 ++++++++++++++++-- .../lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp | 25 ++- .../Dialect/FHELinalg/ops.invalid.mlir | 43 ++++- .../check_tests/Dialect/FHELinalg/ops.mlir | 4 +- 6 files changed, 220 insertions(+), 30 deletions(-) diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h index fee8c6094..cfe8f3198 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h @@ -117,6 +117,9 @@ getStridesFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp); mlir::SmallVector getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp); +/// Get group from the Conv2dOp if defined, or return default value +int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp); + } // namespace FHELinalg } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td index 447c77ca0..d1a5a1983 100644 --- a/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td +++ b/compiler/include/concretelang/Dialect/FHELinalg/IR/FHELinalgOps.td @@ -937,7 +937,8 @@ def FHELinalg_Conv2dOp : FHELinalg_Op<"conv2d", []> { // Since there is no U64ElementsAttr, we use I64 and make sure there is no neg values during verification OptionalAttr:$padding, OptionalAttr:$strides, - OptionalAttr:$dilations + OptionalAttr:$dilations, + OptionalAttr:$group ); let results = (outs Type.predicate, HasStaticShapePred]>>); let hasVerifier = 1; diff --git 
a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp index 04f984f5b..9047ad371 100644 --- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -1507,10 +1507,140 @@ getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input, return paddedInput; } +mlir::Value extractContiguous4DSlice(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Value input, + mlir::RankedTensorType resultType, + llvm::SmallVector sizes, + llvm::SmallVector offsets) { + return rewriter + .create( + loc, resultType, input, + // offset + llvm::SmallVector{ + rewriter.getI64IntegerAttr(offsets[0]), + rewriter.getI64IntegerAttr(offsets[1]), + rewriter.getI64IntegerAttr(offsets[2]), + rewriter.getI64IntegerAttr(offsets[3]), + }, + // sizes + llvm::SmallVector{ + rewriter.getI64IntegerAttr(sizes[0]), + rewriter.getI64IntegerAttr(sizes[1]), + rewriter.getI64IntegerAttr(sizes[2]), + rewriter.getI64IntegerAttr(sizes[3]), + }, + // strides + llvm::SmallVector{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + }) + .getResult(); +} + +/// Create operations for grouped convolution. This will slice the input, +/// weight, and output tensors to apply separate conv2d operations. 
+mlir::LogicalResult +createGroupedConv2D(mlir::PatternRewriter &rewriter, + mlir::concretelang::FHELinalg::Conv2dOp &conv2dOp, + mlir::Value paddedInput, mlir::Value weight, + mlir::Value outputTensor, + mlir::DenseIntElementsAttr stridesAttr, + mlir::DenseIntElementsAttr dilationsAttr, int64_t group) { + + mlir::RankedTensorType inputTy = + paddedInput.getType().cast(); + mlir::Type inputElemTy = inputTy.getElementType(); + llvm::ArrayRef inputShape = inputTy.getShape(); + llvm::SmallVector inputSliceSizes( + {inputShape[0], inputShape[1] / group, inputShape[2], inputShape[3]}); + + mlir::RankedTensorType weightTy = + weight.getType().cast(); + mlir::Type weightElemTy = weightTy.getElementType(); + llvm::ArrayRef weightShape = weightTy.getShape(); + llvm::SmallVector weightSliceSizes( + {weightShape[0] / group, weightShape[1], weightShape[2], weightShape[3]}); + + mlir::RankedTensorType resultTy = + conv2dOp.getResult().getType().cast(); + llvm::ArrayRef resultShape = resultTy.getShape(); + llvm::SmallVector sliceResultSizes = { + resultShape[0], weightSliceSizes[0], resultShape[2], resultShape[3]}; + mlir::RankedTensorType sliceResultType = + mlir::RankedTensorType::get(sliceResultSizes, inputElemTy); + + // slice the input, weight, and output to apply different convolutions and + // store their outputs in a single result found in `finalResult` + mlir::Value finalResult = outputTensor; + for (int g = 0; g < group; g++) { + // input[:][g * (input_C / group) : (g + 1) * (input_C / group)][:][:] + mlir::Value inputSlice = extractContiguous4DSlice( + rewriter, conv2dOp.getLoc(), paddedInput, + mlir::RankedTensorType::get(inputSliceSizes, inputElemTy), + inputSliceSizes, {0, g * inputSliceSizes[1], 0, 0}); + // weight[g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:][:] + mlir::Value weightSlice = extractContiguous4DSlice( + rewriter, conv2dOp.getLoc(), weight, + mlir::RankedTensorType::get(weightSliceSizes, weightElemTy), + weightSliceSizes, {g * 
weightSliceSizes[0], 0, 0, 0}); + // bias[:][g * (weight_F / group) : (g + 1) * (weight_F / group)][:][:] + mlir::Value biasSlice = extractContiguous4DSlice( + rewriter, conv2dOp.getLoc(), outputTensor, sliceResultType, + sliceResultSizes, {0, g * sliceResultSizes[1], 0, 0}); + // attributes for custom linalg named op + auto addOpAttr = rewriter.getNamedAttr( + "add", rewriter.getStringAttr( + mlir::concretelang::FHE::AddEintOp::getOperationName())); + auto mulOpAttr = rewriter.getNamedAttr( + "mul", rewriter.getStringAttr( + mlir::concretelang::FHE::MulEintIntOp::getOperationName())); + // apply conv + mlir::Value convResult = + rewriter + .create( + conv2dOp.getLoc(), sliceResultType, + mlir::ValueRange{inputSlice, weightSlice}, biasSlice, + stridesAttr, dilationsAttr, + llvm::ArrayRef({addOpAttr, mulOpAttr})) + .getResult(0); + // insert result of a single conv in the final result + finalResult = + rewriter + .create( + conv2dOp.getLoc(), convResult, finalResult, + llvm::SmallVector{ + rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(g * sliceResultSizes[1]), + rewriter.getI64IntegerAttr(0), + rewriter.getI64IntegerAttr(0), + }, + llvm::SmallVector{ + rewriter.getI64IntegerAttr(sliceResultSizes[0]), + rewriter.getI64IntegerAttr(sliceResultSizes[1]), + rewriter.getI64IntegerAttr(sliceResultSizes[2]), + rewriter.getI64IntegerAttr(sliceResultSizes[3]), + }, + llvm::SmallVector{ + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + rewriter.getI64IntegerAttr(1), + }) + .getResult(); + } + + rewriter.replaceOp(conv2dOp, finalResult); + return mlir::success(); +} + /// This rewrite pattern transforms any instance of operators -/// `FHELinalg.conv2d` to an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`. -/// The transformation consists of padding the input tensor, and initializing -/// the output tensor with bias values if any. 
+/// `FHELinalg.conv2d` to one or multiple instances of +/// `linalg.conv_2d_nchw_fchw`. The transformation consists of padding the input +/// tensor, and initializing the output tensor with bias values if any. Multiple +/// linalg conv operations can be generated, and their output concatenated in +/// the case of grouped convolution struct FHELinalgConv2dToLinalgConv2d : public ::mlir::OpRewritePattern { FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context) @@ -1537,6 +1667,7 @@ struct FHELinalgConv2dToLinalgConv2d mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp); mlir::SmallVector dilationsInts = mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp); + int64_t group = mlir::concretelang::FHELinalg::getGroupFromConv2d(conv2dOp); // Pad the input tensor according to padding. mlir::SmallVector lowPaddingIncludingNC = {0, 0}; @@ -1602,18 +1733,29 @@ struct FHELinalgConv2dToLinalgConv2d auto stridesAttr = rewriter.getI64VectorAttr(stridesInts); auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts); - auto addOpAttr = rewriter.getNamedAttr( - "add", rewriter.getStringAttr( - mlir::concretelang::FHE::AddEintOp::getOperationName())); - auto mulOpAttr = rewriter.getNamedAttr( - "mul", rewriter.getStringAttr( - mlir::concretelang::FHE::MulEintIntOp::getOperationName())); - rewriter.replaceOpWithNewOp( - conv2dOp, biasInitTensor.getType(), - mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr, - dilationsAttr, - llvm::ArrayRef({addOpAttr, mulOpAttr})); - return mlir::success(); + + // we can directly use linalg::Conv2DNchwFchwOp if group is equal to 1, but + // since there is no support for groups in linalg conv operations, we need + // to slice the different tensors and apply multiple convolutions in case + // group is greater than 1 + if (group == 1) { + auto addOpAttr = rewriter.getNamedAttr( + "add", rewriter.getStringAttr( + mlir::concretelang::FHE::AddEintOp::getOperationName())); + auto mulOpAttr =
rewriter.getNamedAttr( + "mul", + rewriter.getStringAttr( + mlir::concretelang::FHE::MulEintIntOp::getOperationName())); + rewriter.replaceOpWithNewOp( + conv2dOp, biasInitTensor.getType(), + mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr, + dilationsAttr, + llvm::ArrayRef({addOpAttr, mulOpAttr})); + return mlir::success(); + } + return createGroupedConv2D(rewriter, conv2dOp, paddedInput, weight, + biasInitTensor, stridesAttr, dilationsAttr, + group); }; }; diff --git a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp index ee82d3216..5a8236569 100644 --- a/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp +++ b/compiler/lib/Dialect/FHELinalg/IR/FHELinalgOps.cpp @@ -808,6 +808,13 @@ getDilationsFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { return dilationsInts; } +int64_t getGroupFromConv2d(mlir::concretelang::FHELinalg::Conv2dOp &convOp) { llvm::Optional optionalGroup = convOp.group(); + if (optionalGroup.hasValue()) + return optionalGroup.getValue(); + return 1; +} + /// Verify the Conv2d shapes, attributes, and expected output dimensions mlir::LogicalResult Conv2dOp::verify() { auto inputTy = @@ -920,6 +927,11 @@ mlir::LogicalResult Conv2dOp::verify() { } } } + int64_t group = getGroupFromConv2d(*this); + if (group < 1) { + this->emitOpError() << "group must be strictly positive, but got " << group; + return mlir::failure(); + } // Extracting dimensions int64_t inputN = inputShape[0], inputC = inputShape[1], @@ -960,10 +972,15 @@ mlir::LogicalResult Conv2dOp::verify() { << inputN << ") but got " << resultN; return mlir::failure(); } - if (inputC != weightC) { - this->emitOpError() << "expected number of channels in weight to be equal " - "to number of channels in input (" - << inputC << ") but got " << weightC; + if (weightC != inputC / group) { + this->emitOpError() + << "expected number of channels in weight to be equal to " + << inputC / group << " (input_channels / 
group) but got " << weightC; + return mlir::failure(); + } + if (weightF % group != 0) { + this->emitOpError() << "expected number of feature maps (" << weightF + << ") to be a multiple of group (" << group << ")"; return mlir::failure(); } if (weightF != resultC) { diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir index b368fd599..a2c9f99bc 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/ops.invalid.mlir @@ -213,6 +213,41 @@ func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x // ----- +func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.conv2d' op group must be strictly positive, but got 0}} + %1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 0 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> + return %1 : tensor<100x4x15x15x!FHE.eint<2>> +} + +// ----- + + +func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to 1 (input_channels / group) but got 3}} + %1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> + return %1 : tensor<100x4x15x15x!FHE.eint<2>> +} + +// ----- + + +func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x1x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.conv2d' op expected number of feature maps (4) to be a multiple of group (3)}} + %1 = 
"FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x1x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> + return %1 : tensor<100x4x15x15x!FHE.eint<2>> +} + +// ----- + + +func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<6x1x14x14xi3>, %bias: tensor<6xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { + // expected-error @+1 {{'FHELinalg.conv2d' op expected number of output channels to be equal to the number of filters (6) but got 4}} + %1 = "FHELinalg.conv2d"(%input, %weight, %bias){group = 3 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<6x1x14x14xi3>, tensor<6xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> + return %1 : tensor<100x4x15x15x!FHE.eint<2>> +} + +// ----- + func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { // expected-error @+1 {{'FHELinalg.conv2d' op expected height of output to be equal to 8 but got 15}} %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[2,2]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0 ,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> @@ -231,14 +266,6 @@ func.func @conv2d(%input: tensor<101x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x // ----- -func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x4x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { - // expected-error @+1 {{'FHELinalg.conv2d' op expected number of channels in weight to be equal to number of channels in input (3) but got 4}} - %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x4x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> - return 
%1 : tensor<100x4x15x15x!FHE.eint<2>> -} - -// ----- - func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi4>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { // expected-error @+1 {{'FHELinalg.conv2d' op expected weight element type to have width 3 but got 4}} diff --git a/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir b/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir index 9f3c8ce83..2f0ae3b21 100644 --- a/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir +++ b/compiler/tests/check_tests/Dialect/FHELinalg/ops.mlir @@ -435,11 +435,11 @@ func.func @dot_eint_int(%arg0: tensor<2x!FHE.eint<2>>, ///////////////////////////////////////////////// // CHECK: func.func @conv2d(%[[ARG0:.*]]: tensor<100x3x28x28x!FHE.eint<2>>, %[[ARG1:.*]]: tensor<4x3x14x14xi3>, %[[ARG2:.*]]: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { -// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> +// CHECK-NEXT: %[[V0:.*]] = "FHELinalg.conv2d"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) {dilations = dense<1> : tensor<2xi64>, group = 1 : i64, padding = dense<0> : tensor<4xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> // CHECK-NEXT: return %[[V0]] : tensor<100x4x15x15x!FHE.eint<2>> // CHECK-NEXT: } func.func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { - %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> 
tensor<100x4x15x15x!FHE.eint<2>> + %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0,0, 0, 0]> : tensor<4xi64>, group = 1 : i64}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> return %1 : tensor<100x4x15x15x!FHE.eint<2>> }