mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 03:55:04 -05:00)
feat: lower FHELinalg.conv2d to linalg
This currently lowers to our custom linalg conv operation, since linalg ops don't support custom types. Once linalg supports custom types, we may want to lower to the native linalg op instead.
@@ -9,7 +9,9 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -1169,6 +1171,142 @@ struct ConcatRewritePattern
};
};

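// Convert static sizes to IndexAttr OpFoldResults, the form expected by the
// PadTensorOp builders used below.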
static mlir::SmallVector<mlir::OpFoldResult>
getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc,
                  mlir::SmallVectorImpl<int64_t> &ints) {
  return llvm::to_vector<4>(
      llvm::map_range(ints, [&](int64_t val) -> mlir::OpFoldResult {
        return b.getIndexAttr(val);
      }));
}

// Helper function that returns the input tensor padded with the given value,
// according to the given low/high padding amounts
static mlir::Value
getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
                mlir::SmallVectorImpl<int64_t> &lowPaddingInts,
                mlir::SmallVectorImpl<int64_t> &highPaddingInts,
                mlir::Value pad) {
  assert(input.getType().isa<mlir::RankedTensorType>() &&
         "input must be RankedTensorType");
  mlir::Location loc = op->getLoc();
  mlir::Type rankedTensorType = mlir::linalg::PadTensorOp::inferResultType(
      input.getType().cast<mlir::RankedTensorType>(), lowPaddingInts,
      highPaddingInts);
  mlir::SmallVector<mlir::OpFoldResult> lowPaddings =
      getAsOpFoldResult(b, loc, lowPaddingInts);
  mlir::SmallVector<mlir::OpFoldResult> highPaddings =
      getAsOpFoldResult(b, loc, highPaddingInts);
  mlir::Value paddedInput = mlir::linalg::PadTensorOp::createPadScalarOp(
      rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
      /*packing=*/false, loc, b);
  return paddedInput;
}
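
For illustration, here is a minimal sketch of how this helper is called; `builder`, `input`, and `zeroEint` are hypothetical stand-ins for the values that the rewrite pattern below actually supplies:

// Hypothetical call: pad only H and W of an NCHW tensor with an encrypted
// zero. With these amounts, tensor<100x3x28x28x!FHE.eint<2>> becomes
// tensor<100x3x30x30x!FHE.eint<2>> (28 + 1 + 1 = 30 on each spatial dim).
mlir::SmallVector<int64_t, 4> low = {0, 0, 1, 1};
mlir::SmallVector<int64_t, 4> high = {0, 0, 1, 1};
mlir::Value padded = getPaddedTensor(op, builder, input, low, high, zeroEint);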

// This rewrite pattern transforms any instance of the `FHELinalg.conv2d`
// operator into an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`.
// The transformation consists of padding the input tensor and initializing
// the output tensor with the bias values, if any.
struct FHELinalgConv2dToLinalgConv2d
    : public ::mlir::OpRewritePattern<mlir::concretelang::FHELinalg::Conv2dOp> {
  FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)
      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp>(
            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}

  ::mlir::LogicalResult
  matchAndRewrite(::mlir::concretelang::FHELinalg::Conv2dOp conv2dOp,
                  ::mlir::PatternRewriter &rewriter) const override {

    mlir::Location loc = conv2dOp->getLoc();
    mlir::Value input =
        conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
    mlir::Value weight =
        conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */

    mlir::Type inputElementType =
        input.getType().cast<mlir::RankedTensorType>().getElementType();

    // Attributes are assumed to be correct, since verification has passed
    mlir::SmallVector<int64_t, 4> paddingInts =
        mlir::concretelang::FHELinalg::getPaddingFromConv2d(conv2dOp);
    mlir::SmallVector<int64_t, 2> stridesInts =
        mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp);
    mlir::SmallVector<int64_t, 2> dilationsInts =
        mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp);

    // Pad the input tensor according to padding.
    mlir::SmallVector<int64_t, 4> lowPaddingIncludingNC = {0, 0};
    lowPaddingIncludingNC.insert(lowPaddingIncludingNC.end(),
                                 paddingInts.begin() + 2, paddingInts.end());
    mlir::SmallVector<int64_t, 4> highPaddingIncludingNC = {0, 0};
    highPaddingIncludingNC.insert(highPaddingIncludingNC.end(),
                                  paddingInts.begin(), paddingInts.begin() + 2);
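    // The 4-element `padding` attribute is split so that elements [2] and [3]
    // become the low padding of H and W, and elements [0] and [1] the high
    // padding; the batch and channel dimensions are never padded.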
    mlir::Value paddingValue =
        rewriter.create<mlir::concretelang::FHE::ZeroEintOp>(
            loc,
            input.getType().cast<mlir::RankedTensorType>().getElementType());
    mlir::Value paddedInput =
        getPaddedTensor(conv2dOp, rewriter, input, lowPaddingIncludingNC,
                        highPaddingIncludingNC, paddingValue);

    // TODO(Optimization): the output tensor is currently constructed in two
    // different ways, depending on whether there is a bias:
    // 1- No bias: we initialize the output tensor to encryptions of zero.
    // 2- Bias: we initialize the output tensor to encryptions of zero, then
    //    add the bias values.
    // For the second case, this could instead be done by initializing the
    // output to encryptions of the bias values directly.
    mlir::Value initTensor =
        rewriter.create<mlir::concretelang::FHE::ZeroTensorOp>(
            loc, mlir::RankedTensorType::get(conv2dOp.getResult()
                                                 .getType()
                                                 .cast<mlir::RankedTensorType>()
                                                 .getShape(),
                                             inputElementType));
    // Since linalg doesn't support a bias in the conv operation, we initialize
    // the output tensor to the bias values, so that conv results get
    // accumulated into it
    mlir::Value bias = conv2dOp.bias(); /* optional, of shape: Filters */
    mlir::Value biasInitTensor;
    if (!bias) { // no bias was used
      biasInitTensor = initTensor;
    } else {
      // Fill the output tensor with bias values
      auto resultRank =
          initTensor.getType().cast<mlir::RankedTensorType>().getRank();
      mlir::SmallVector<mlir::AffineMap> indexingMaps = {
          mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
                               rewriter.getContext()),
          rewriter.getMultiDimIdentityMap(resultRank)};
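      // With resultRank == 4, these two maps read:
      //   (d0, d1, d2, d3) -> (d1)              for the bias operand
      //   (d0, d1, d2, d3) -> (d0, d1, d2, d3)  for the output
      // i.e. the bias value of filter f is broadcast to every (n, f, h, w)
      // output position (#map0/#map1 in the test below).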
      mlir::SmallVector<llvm::StringRef> iteratorTypes(resultRank, "parallel");
      biasInitTensor =
          rewriter
              .create<mlir::linalg::GenericOp>(
                  loc, initTensor.getType(), bias, initTensor, indexingMaps,
                  iteratorTypes,
                  [](mlir::OpBuilder &b, mlir::Location loc,
                     mlir::ValueRange args) {
                    mlir::Value encryptedBias =
                        b.create<mlir::concretelang::FHE::AddEintIntOp>(
                             loc, args[1], args[0])
                            .getResult();
                    b.create<mlir::linalg::YieldOp>(loc, encryptedBias);
                  })
              .getResult(0);
    }

    auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
    auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
    rewriter.replaceOpWithNewOp<
        mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>(
        conv2dOp, biasInitTensor.getType(),
        mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
        dilationsAttr);
    return mlir::success();
  }
};
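
As a side note on the TODO above, here is a rough sketch of the single-generic bias initialization it suggests, assuming linalg.init_tensor can serve as a plain output placeholder for the custom element type (hypothetical and untested; `resultShape` stands for the conv output shape, and `indexingMaps`/`iteratorTypes` are the same as in the pattern above):

// Hypothetical variant: encrypt the bias inside a single linalg.generic,
// avoiding the separate FHE.zero_tensor materialization.
mlir::Value placeholder = rewriter.create<mlir::linalg::InitTensorOp>(
    loc, resultShape, inputElementType);
mlir::Value biasInit =
    rewriter
        .create<mlir::linalg::GenericOp>(
            loc, placeholder.getType(), bias, placeholder, indexingMaps,
            iteratorTypes,
            [&](mlir::OpBuilder &b, mlir::Location loc,
                mlir::ValueRange args) {
              // Encryption of zero, plus the cleartext bias value.
              mlir::Value zero =
                  b.create<mlir::concretelang::FHE::ZeroEintOp>(
                      loc, inputElementType);
              mlir::Value enc =
                  b.create<mlir::concretelang::FHE::AddEintIntOp>(loc, zero,
                                                                  args[0]);
              b.create<mlir::linalg::YieldOp>(loc, enc);
            })
        .getResult(0);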

namespace {
struct FHETensorOpsToLinalg
    : public FHETensorOpsToLinalgBase<FHETensorOpsToLinalg> {

@@ -1189,6 +1327,9 @@ void FHETensorOpsToLinalg::runOnFunction() {
  target.addLegalDialect<mlir::arith::ArithmeticDialect>();
  target.addIllegalOp<mlir::concretelang::FHELinalg::Dot>();
  target.addIllegalDialect<mlir::concretelang::FHELinalg::FHELinalgDialect>();
  // TODO: this should be removed when we no longer need a custom generated op
  // for conv that works on tensors of custom types
  target.addLegalOp<mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>();
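  // The custom op is converted later in the pipeline: the conv2d.mlir test
  // below checks the IR produced by --action=dump-tfhe, i.e. after the
  // FHE-to-TFHE conversion has also run on it.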

  mlir::OwningRewritePatternList patterns(&getContext());
  patterns.insert<DotToLinalgGeneric>(&getContext());

@@ -1229,6 +1370,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
                                                       &getContext());
  patterns.insert<SumToLinalgGeneric>(&getContext());
  patterns.insert<ConcatRewritePattern>(&getContext());
  patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());

  if (mlir::applyPartialConversion(function, target, std::move(patterns))
          .failed())
compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir (new file, 29 lines)
@@ -0,0 +1,29 @@
// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s

//CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d1)>
//CHECK-NEXT: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
//CHECK-NEXT: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
//CHECK-NEXT: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
//CHECK-NEXT: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
//CHECK-NEXT: module {
//CHECK-NEXT: func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> {
//CHECK-NEXT: %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
//CHECK-NEXT: ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>): // no predecessors
//CHECK-NEXT: %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
//CHECK-NEXT: } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
//CHECK-NEXT: }
//CHECK-NEXT: }

func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
}
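
As a sanity check on the shapes in this test: with no padding, stride 1, and dilation 1, each spatial dimension follows the usual convolution output-size formula, out = (in + padLow + padHigh - dilation * (kernel - 1) - 1) / stride + 1, so 28x28 inputs with 14x14 kernels give 15x15 outputs. A small self-contained sketch of that arithmetic (the helper name is ours, not part of the commit):

#include <cassert>
#include <cstdint>

// Standard conv2d output-size arithmetic, shown here only to justify the
// tensor<100x4x15x15> result type checked in the test above.
static int64_t convOutputSize(int64_t in, int64_t kernel, int64_t padLow,
                              int64_t padHigh, int64_t stride,
                              int64_t dilation) {
  return (in + padLow + padHigh - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // 28x28 input, 14x14 kernel, padding 0, stride 1, dilation 1 -> 15x15.
  assert(convOutputSize(28, 14, 0, 0, 1, 1) == 15);
  return 0;
}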