feat: lower FHELinalg.conv2d to linalg

This currently lowers to our custom linalg conv operation, since linalg
ops don't support custom types. Once linalg supports custom types, we
may want to lower to the native linalg op instead.
Author: youben11
Date: 2022-02-16 15:33:34 +01:00
Committed by: Ayoub Benaissa
Parent: 6d2f853c07
Commit: 86379096df
2 changed files with 171 additions and 0 deletions

@@ -9,7 +9,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -1169,6 +1171,142 @@ struct ConcatRewritePattern
};
};
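// Helper function to convert a list of integer values into OpFoldResults
// holding index attributes, as expected by the pad-op builder used below.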
static mlir::SmallVector<mlir::OpFoldResult>
getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc,
mlir::SmallVectorImpl<int64_t> &ints) {
return llvm::to_vector<4>(
llvm::map_range(ints, [&](int64_t val) -> mlir::OpFoldResult {
return b.getIndexAttr(val);
}));
}
// Helper function to build the padded tensor, given the low/high padding int
// values and the value to pad with.
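// For instance (illustrative shapes), padding a tensor<1x2x4x4> input with
// lowPaddingInts = {0, 0, 1, 1} and highPaddingInts = {0, 0, 1, 1} yields a
// tensor<1x2x6x6> whose border elements are all `pad`.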
static mlir::Value
getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
mlir::SmallVectorImpl<int64_t> &lowPaddingInts,
mlir::SmallVectorImpl<int64_t> &highPaddingInts,
mlir::Value pad) {
assert(input.getType().isa<mlir::RankedTensorType>() &&
"input must be RankedTensorType");
mlir::Location loc = op->getLoc();
mlir::Type rankedTensorType = mlir::linalg::PadTensorOp::inferResultType(
input.getType().cast<mlir::RankedTensorType>(), lowPaddingInts,
highPaddingInts);
mlir::SmallVector<mlir::OpFoldResult> lowPaddings =
getAsOpFoldResult(b, loc, lowPaddingInts);
mlir::SmallVector<mlir::OpFoldResult> highPaddings =
getAsOpFoldResult(b, loc, highPaddingInts);
mlir::Value paddedInput = mlir::linalg::PadTensorOp::createPadScalarOp(
rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
/*packing=*/false, loc, b);
return paddedInput;
}
// This rewrite pattern transforms any instance of the `FHELinalg.conv2d`
// operator into an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`. The
// transformation consists of padding the input tensor and initializing the
// output tensor with the bias values, if any.
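// Rough sketch of the rewrite (SSA names are illustrative, not from a real
// trace):
//   %r = "FHELinalg.conv2d"(%input, %weight, %bias) {strides, dilations, padding}
// becomes approximately
//   %padded = pad %input on H/W with an FHE::ZeroEintOp value
//   %init   = FHE::ZeroTensorOp of the result shape
//   %acc    = linalg.generic broadcasting %bias into %init (only when a bias is given)
//   %r      = linalg.fhelinalg_conv_2d_nchw_fchw ins(%padded, %weight) outs(%acc)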
struct FHELinalgConv2dToLinalgConv2d
: public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Conv2dOp> {
FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)
: ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp>(
context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
::mlir::LogicalResult
matchAndRewrite(::mlir::concretelang::FHELinalg::Conv2dOp conv2dOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = conv2dOp->getLoc();
mlir::Value input =
conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
mlir::Value weight =
conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */
mlir::Type inputElementType =
input.getType().cast<mlir::RankedTensorType>().getElementType();
// Attributes are assumed to be correct after passing the verification
mlir::SmallVector<int64_t, 4> paddingInts =
mlir::concretelang::FHELinalg::getPaddingFromConv2d(conv2dOp);
mlir::SmallVector<int64_t, 2> stridesInts =
mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp);
mlir::SmallVector<int64_t, 2> dilationsInts =
mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp);
// Pad the input tensor according to padding.
mlir::SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
lowPaddingIncludingNC.insert(lowPaddingIncludingNC.end(),
paddingInts.begin() + 2, paddingInts.end());
mlir::SmallVector<int64_t, 4> highPaddingIncludingNC = {0, 0};
highPaddingIncludingNC.insert(highPaddingIncludingNC.end(),
paddingInts.begin(), paddingInts.begin() + 2);
mlir::Value paddingValue =
rewriter.create<mlir::concretelang::FHE::ZeroEintOp>(
loc,
input.getType().cast<mlir::RankedTensorType>().getElementType());
mlir::Value paddedInput =
getPaddedTensor(conv2dOp, rewriter, input, lowPaddingIncludingNC,
highPaddingIncludingNC, paddingValue);
// TODO(Optimization): the output tensor is being constructed in two different
// ways, depending on whether there is a bias or not:
// 1- No bias: we initialize the output tensor to encryptions of zero.
// 2- Bias: we initialize the output tensor to encryptions of zero, then add
// the bias values.
// For the second case, this could be done by initializing the output to
// encryptions of the bias values directly.
mlir::Value initTensor =
rewriter.create<mlir::concretelang::FHE::ZeroTensorOp>(
loc, mlir::RankedTensorType::get(conv2dOp.getResult()
.getType()
.cast<mlir::RankedTensorType>()
.getShape(),
inputElementType));
// Since linalg doesn't support a bias in the conv operation, we initialize
// the output tensor to the bias values, so that the conv results get
// accumulated into it
mlir::Value bias = conv2dOp.bias(); /* optional of shape: Filters */
mlir::Value biasInitTensor;
if (!bias) { // no bias was used
biasInitTensor = initTensor;
} else {
// Fill the output tensor with bias values
auto resultRank =
initTensor.getType().cast<mlir::RankedTensorType>().getRank();
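// Indexing maps: the 1-D bias tensor (shape: Filters) is broadcast along the
// output-channel dimension, i.e. (d0, d1, d2, d3) -> (d1), while the output
// tensor uses the identity map.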
mlir::SmallVector<mlir::AffineMap> indexingMaps = {
mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
rewriter.getContext()),
rewriter.getMultiDimIdentityMap(resultRank)};
mlir::SmallVector<llvm::StringRef> iteratorTypes(resultRank, "parallel");
biasInitTensor =
rewriter
.create<mlir::linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor, indexingMaps,
iteratorTypes,
[](mlir::OpBuilder &b, mlir::Location loc,
mlir::ValueRange args) {
mlir::Value encryptedBias =
b.create<mlir::concretelang::FHE::AddEintIntOp>(
loc, args[1], args[0])
.getResult();
b.create<mlir::linalg::YieldOp>(loc, encryptedBias);
})
.getResult(0);
}
auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
rewriter.replaceOpWithNewOp<
mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>(
conv2dOp, biasInitTensor.getType(),
mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
dilationsAttr);
return mlir::success();
};
};
namespace {
struct FHETensorOpsToLinalg
: public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {
@@ -1189,6 +1327,9 @@ void FHETensorOpsToLinalg::runOnFunction() {
target.addLegalDialect<mlir::arith::ArithmeticDialect>();
target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
// TODO: this should be removed once we no longer need a custom-generated op
// for conv that works on tensors of custom types
target.addLegalOp<mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>();
mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<DotToLinalgGeneric>(&getContext());
@@ -1229,6 +1370,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
&getContext());
patterns.insert<SumToLinalgGeneric>(&getContext());
patterns.insert<ConcatRewritePattern>(&getContext());
patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
if (mlir::applyPartialConversion(function, target, std::move(patterns))
.failed())

@@ -0,0 +1,29 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
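// Checks the IR after lowering to TFHE: the conv2d is expressed as two
// linalg.generic ops, one broadcasting the bias into the zero-initialized
// output tensor and one performing the convolution itself.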
//CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d1)>
//CHECK-NEXT: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//CHECK-NEXT: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
//CHECK-NEXT: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
//CHECK-NEXT: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }
func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
%1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}