From 86379096df6a03e8cec82227a8a046cf3bfebdd6 Mon Sep 17 00:00:00 2001
From: youben11
Date: Wed, 16 Feb 2022 15:33:34 +0100
Subject: [PATCH] feat: lower FHELinalg.conv2d to linalg

This currently lowers to our custom linalg conv operation, since
linalg ops don't support custom types. Once linalg supports custom
types, we may want to lower to the native linalg op instead.
---
 .../TensorOpsToLinalg.cpp                     | 142 ++++++++++++++++++
 .../FHEToTFHE/FHEToTFHE/conv2d.mlir           |  29 ++++
 2 files changed, 171 insertions(+)
 create mode 100644 compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir

diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 710f6751f..9addc4c13 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -9,7 +9,9 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -1169,6 +1171,142 @@ struct ConcatRewritePattern
 };
 };
 
+// Helper function converting a list of integers into index attributes
+static mlir::SmallVector<mlir::OpFoldResult>
+getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc,
+                  mlir::SmallVectorImpl<int64_t> &ints) {
+  return llvm::to_vector<4>(
+      llvm::map_range(ints, [&](int64_t val) -> mlir::OpFoldResult {
+        return b.getIndexAttr(val);
+      }));
+}
+
+// Helper function to get the padded tensor, given the integer padding
+// values and the value to pad with
+static mlir::Value
+getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
+                mlir::SmallVectorImpl<int64_t> &lowPaddingInts,
+                mlir::SmallVectorImpl<int64_t> &highPaddingInts,
+                mlir::Value pad) {
+  assert(input.getType().isa<mlir::RankedTensorType>() &&
+         "input must be RankedTensorType");
+  mlir::Location loc = op->getLoc();
+  mlir::Type rankedTensorType = mlir::linalg::PadTensorOp::inferResultType(
+      input.getType().cast<mlir::RankedTensorType>(), lowPaddingInts,
+      highPaddingInts);
+  mlir::SmallVector<mlir::OpFoldResult> lowPaddings =
+      getAsOpFoldResult(b, loc, lowPaddingInts);
+  mlir::SmallVector<mlir::OpFoldResult> highPaddings =
+      getAsOpFoldResult(b, loc, highPaddingInts);
+  mlir::Value paddedInput = mlir::linalg::PadTensorOp::createPadScalarOp(
+      rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
+      /*packing=*/false, loc, b);
+  return paddedInput;
+}
+
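+// For illustration (shapes borrowed from the test below): padding an input
+// of type tensor<100x3x28x28x!FHE.eint<2>> with lowPaddingInts = {0, 0, 1, 1}
+// and highPaddingInts = {0, 0, 1, 1} yields a tensor<100x3x30x30x!FHE.eint<2>>
+// whose new border elements all hold the `pad` value.
+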
+// This rewrite pattern transforms any instance of the `FHELinalg.conv2d`
+// operator into an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`. The
+// transformation consists of padding the input tensor and initializing the
+// output tensor with the bias values, if any.
+struct FHELinalgConv2dToLinalgConv2d
+    : public ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp> {
+  FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)
+      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp>(
+            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(::mlir::concretelang::FHELinalg::Conv2dOp conv2dOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Location loc = conv2dOp->getLoc();
+    mlir::Value input =
+        conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
+    mlir::Value weight =
+        conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */
+
+    mlir::Type inputElementType =
+        input.getType().cast<mlir::RankedTensorType>().getElementType();
+
+    // Attributes are assumed to be correct, as they have already passed
+    // verification
+    mlir::SmallVector<int64_t> paddingInts =
+        mlir::concretelang::FHELinalg::getPaddingFromConv2d(conv2dOp);
+    mlir::SmallVector<int64_t> stridesInts =
+        mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp);
+    mlir::SmallVector<int64_t> dilationsInts =
+        mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp);
+
+    // Pad the input tensor according to padding.
+    mlir::SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
+    lowPaddingIncludingNC.insert(lowPaddingIncludingNC.end(),
+                                 paddingInts.begin() + 2, paddingInts.end());
+    mlir::SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
+    highPaddingIncludingNC.insert(highPaddingIncludingNC.end(),
+                                  paddingInts.begin(), paddingInts.begin() + 2);
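+    // Note: only the spatial dimensions (H, W) receive padding from the
+    // `padding` attribute; the batch (N) and channel (C) dimensions are never
+    // padded, hence the {0, 0} prefix of both vectors above. The value used
+    // for padding below is an encryption of zero.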
+    mlir::Value paddingValue =
+        rewriter.create<mlir::concretelang::FHE::ZeroEintOp>(
+            loc,
+            input.getType().cast<mlir::RankedTensorType>().getElementType());
+    mlir::Value paddedInput =
+        getPaddedTensor(conv2dOp, rewriter, input, lowPaddingIncludingNC,
+                        highPaddingIncludingNC, paddingValue);
+
+    // TODO(Optimization): the output tensor is being constructed in two
+    // different ways, depending on whether there is a bias:
+    // 1- No bias: we initialize the output tensor to encryptions of zero
+    // 2- Bias: we initialize the output tensor to encryptions of zero, then
+    //    add the bias values
+    // The second case could instead initialize the output directly to
+    // encryptions of the bias values
+    mlir::Value initTensor =
+        rewriter.create<mlir::concretelang::FHE::ZeroTensorOp>(
+            loc, mlir::RankedTensorType::get(conv2dOp.getResult()
+                                                 .getType()
+                                                 .cast<mlir::RankedTensorType>()
+                                                 .getShape(),
+                                             inputElementType));
+    // Since linalg doesn't support a bias in the conv operation, we
+    // initialize the output tensor with the bias values, so that the conv
+    // results get accumulated into it
+    mlir::Value bias = conv2dOp.bias(); /* optional of shape: Filters */
+    mlir::Value biasInitTensor;
+    if (!bias) { // no bias was used
+      biasInitTensor = initTensor;
+    } else {
+      // Fill the output tensor with bias values
+      auto resultRank =
+          initTensor.getType().cast<mlir::RankedTensorType>().getRank();
+      mlir::SmallVector<mlir::AffineMap> indexingMaps = {
+          mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
+                               rewriter.getContext()),
+          rewriter.getMultiDimIdentityMap(resultRank)};
+      mlir::SmallVector<mlir::StringRef> iteratorTypes(resultRank, "parallel");
+      biasInitTensor =
+          rewriter
+              .create<mlir::linalg::GenericOp>(
+                  loc, initTensor.getType(), bias, initTensor, indexingMaps,
+                  iteratorTypes,
+                  [](mlir::OpBuilder &b, mlir::Location loc,
+                     mlir::ValueRange args) {
+                    mlir::Value encryptedBias =
+                        b.create<mlir::concretelang::FHE::AddEintIntOp>(
+                             loc, args[1], args[0])
+                            .getResult();
+                    b.create<mlir::linalg::YieldOp>(loc, encryptedBias);
+                  })
+              .getResult(0);
+    }
+
+    auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
+    rewriter.replaceOpWithNewOp<
+        mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>(
+        conv2dOp, biasInitTensor.getType(),
+        mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
+        dilationsAttr);
+    return mlir::success();
+  };
+};
+
 namespace {
 struct FHETensorOpsToLinalg
     : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1189,6 +1327,9 @@ void FHETensorOpsToLinalg::runOnFunction() {
   target.addLegalDialect();
   target.addIllegalOp();
   target.addIllegalDialect();
+  // TODO: this should be removed when we no longer need a custom generated op
+  // for conv that works on tensors of custom types
+  target.addLegalOp<mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>();
   mlir::OwningRewritePatternList patterns(&getContext());
   patterns.insert(&getContext());
@@ -1229,6 +1370,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
                   &getContext());
   patterns.insert(&getContext());
   patterns.insert(&getContext());
+  patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
diff --git a/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir
new file mode 100644
index 000000000..7d0e23113
--- /dev/null
+++ b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir
@@ -0,0 +1,29 @@
+// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
+
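+// The conv2d below should lower to two linalg.generic ops: one initializing
+// the output tensor with the bias values (added to encrypted zeros), and one
+// accumulating the convolution into that output. The CHECK lines match the IR
+// after the subsequent FHE to TFHE conversion.
+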
"parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>): // no predecessors +//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: } + +func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> { + %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> + return %1 : tensor<100x4x15x15x!FHE.eint<2>> +} \ No newline at end of file