diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 710f6751f..9addc4c13 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -9,7 +9,9 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -1169,6 +1171,142 @@ struct ConcatRewritePattern
 };
 };
 
+static mlir::SmallVector<mlir::OpFoldResult>
+getAsOpFoldResult(mlir::OpBuilder &b, mlir::Location loc,
+                  mlir::SmallVectorImpl<int64_t> &ints) {
+  return llvm::to_vector<4>(
+      llvm::map_range(ints, [&](int64_t val) -> mlir::OpFoldResult {
+        return b.getIndexAttr(val);
+      }));
+}
+
+// Helper function to build the padded tensor given the low/high padding int
+// values and the value to pad with
+static mlir::Value
+getPaddedTensor(mlir::Operation *op, mlir::OpBuilder &b, mlir::Value &input,
+                mlir::SmallVectorImpl<int64_t> &lowPaddingInts,
+                mlir::SmallVectorImpl<int64_t> &highPaddingInts,
+                mlir::Value pad) {
+  assert(input.getType().isa<mlir::RankedTensorType>() &&
+         "input must be RankedTensorType");
+  mlir::Location loc = op->getLoc();
+  mlir::Type rankedTensorType = mlir::linalg::PadTensorOp::inferResultType(
+      input.getType().cast<mlir::RankedTensorType>(), lowPaddingInts,
+      highPaddingInts);
+  mlir::SmallVector<mlir::OpFoldResult> lowPaddings =
+      getAsOpFoldResult(b, loc, lowPaddingInts);
+  mlir::SmallVector<mlir::OpFoldResult> highPaddings =
+      getAsOpFoldResult(b, loc, highPaddingInts);
+  mlir::Value paddedInput = mlir::linalg::PadTensorOp::createPadScalarOp(
+      rankedTensorType, input, pad, /*low=*/lowPaddings, /*high=*/highPaddings,
+      /*packing=*/false, loc, b);
+  return paddedInput;
+}
+
+// This rewrite pattern transforms any instance of the `FHELinalg.conv2d`
+// operator into an instance of `linalg.fhelinalg_conv_2d_nchw_fchw`.
+// The transformation consists of padding the input tensor and initializing
+// the output tensor with the bias values, if any.
+struct FHELinalgConv2dToLinalgConv2d
+    : public ::mlir::OpRewritePattern<
+          ::mlir::concretelang::FHELinalg::Conv2dOp> {
+  FHELinalgConv2dToLinalgConv2d(::mlir::MLIRContext *context)
+      : ::mlir::OpRewritePattern<::mlir::concretelang::FHELinalg::Conv2dOp>(
+            context, mlir::concretelang::DEFAULT_PATTERN_BENEFIT) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(::mlir::concretelang::FHELinalg::Conv2dOp conv2dOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    mlir::Location loc = conv2dOp->getLoc();
+    mlir::Value input =
+        conv2dOp.input(); /* shape: Batch*Channels*Height*Width */
+    mlir::Value weight =
+        conv2dOp.weight(); /* shape: Filters*Channels*Height*Width */
+
+    mlir::Type inputElementType =
+        input.getType().cast<mlir::RankedTensorType>().getElementType();
+
+    // Attributes are assumed to be correct after passing verification
+    mlir::SmallVector<int64_t> paddingInts =
+        mlir::concretelang::FHELinalg::getPaddingFromConv2d(conv2dOp);
+    mlir::SmallVector<int64_t> stridesInts =
+        mlir::concretelang::FHELinalg::getStridesFromConv2d(conv2dOp);
+    mlir::SmallVector<int64_t> dilationsInts =
+        mlir::concretelang::FHELinalg::getDilationsFromConv2d(conv2dOp);
+
+    // Pad the input tensor according to padding.
+    mlir::SmallVector<int64_t> lowPaddingIncludingNC = {0, 0};
+    lowPaddingIncludingNC.insert(lowPaddingIncludingNC.end(),
+                                 paddingInts.begin() + 2, paddingInts.end());
+    mlir::SmallVector<int64_t> highPaddingIncludingNC = {0, 0};
+    highPaddingIncludingNC.insert(highPaddingIncludingNC.end(),
+                                  paddingInts.begin(), paddingInts.begin() + 2);
+    mlir::Value paddingValue =
+        rewriter.create(
+            loc,
+            input.getType().cast<mlir::RankedTensorType>().getElementType());
+    mlir::Value paddedInput =
+        getPaddedTensor(conv2dOp, rewriter, input, lowPaddingIncludingNC,
+                        highPaddingIncludingNC, paddingValue);
+
+    // TODO(Optimization): the output tensor is being constructed in two
+    // different ways, depending on whether there is a bias or not:
+    // 1- There is no bias: we initialize the output tensor to encryptions of
+    //    zero
+    // 2- There is a bias: we initialize the output tensor to encryptions of
+    //    zero, then we add the bias values.
+    // For the second case, this could be done by initializing the output to
+    // encryptions of the bias values directly
+    mlir::Value initTensor =
+        rewriter.create(
+            loc, mlir::RankedTensorType::get(conv2dOp.getResult()
+                                                 .getType()
+                                                 .cast<mlir::RankedTensorType>()
+                                                 .getShape(),
+                                             inputElementType));
+    // Since linalg doesn't support a bias in the conv operation, we initialize
+    // the output tensor to the bias values, so that conv results get
+    // accumulated into it
+    mlir::Value bias = conv2dOp.bias(); /* optional, of shape: Filters */
+    mlir::Value biasInitTensor;
+    if (!bias) { // no bias was used
+      biasInitTensor = initTensor;
+    } else {
+      // Fill the output tensor with the bias values
+      auto resultRank =
+          initTensor.getType().cast<mlir::RankedTensorType>().getRank();
+      mlir::SmallVector<mlir::AffineMap> indexingMaps = {
+          mlir::AffineMap::get(resultRank, 0, rewriter.getAffineDimExpr(1),
+                               rewriter.getContext()),
+          rewriter.getMultiDimIdentityMap(resultRank)};
+      mlir::SmallVector<llvm::StringRef> iteratorTypes(resultRank, "parallel");
+      biasInitTensor =
+          rewriter
+              .create<mlir::linalg::GenericOp>(
+                  loc, initTensor.getType(), bias, initTensor, indexingMaps,
+                  iteratorTypes,
+                  [](mlir::OpBuilder &b, mlir::Location loc,
+                     mlir::ValueRange args) {
+                    mlir::Value encryptedBias =
+                        b.create(
+                             loc, args[1], args[0])
+                            .getResult();
+                    b.create<mlir::linalg::YieldOp>(loc, encryptedBias);
+                  })
+              .getResult(0);
+    }
+
+    auto stridesAttr = rewriter.getI64VectorAttr(stridesInts);
+    auto dilationsAttr = rewriter.getI64VectorAttr(dilationsInts);
+    rewriter.replaceOpWithNewOp<
+        mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>(
+        conv2dOp, biasInitTensor.getType(),
+        mlir::ValueRange{paddedInput, weight}, biasInitTensor, stridesAttr,
+        dilationsAttr);
+    return mlir::success();
+  };
+};
+
 namespace {
 struct FHETensorOpsToLinalg : public FHETensorOpsToLinalgBase {
@@ -1189,6 +1327,9 @@ void FHETensorOpsToLinalg::runOnFunction() {
   target.addLegalDialect();
   target.addIllegalOp();
   target.addIllegalDialect();
+  // TODO: this should be removed when we no longer need a custom generated op
+  // for conv that works on tensors of custom types
+  target.addLegalOp<
+      mlir::concretelang::FHELinalg::FhelinalgConv2DNchwFchwOp>();
   mlir::OwningRewritePatternList patterns(&getContext());
   patterns.insert(&getContext());
@@ -1229,6 +1370,7 @@ void FHETensorOpsToLinalg::runOnFunction() {
       &getContext());
   patterns.insert(&getContext());
   patterns.insert(&getContext());
+  patterns.insert<FHELinalgConv2dToLinalgConv2d>(&getContext());
 
   if (mlir::applyPartialConversion(function, target, std::move(patterns))
           .failed())
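For reference, the output shape produced by this lowering follows the usual strided/dilated convolution arithmetic over the padded input. The snippet below is a minimal standalone sketch of that formula and is not part of the patch; the helper name convOutputSize is made up for illustration, and the numbers simply match the shapes used in the new test that follows (28x28 input, 14x14 kernel, stride 1, dilation 1, no padding, giving a 15x15 output).

    #include <cassert>
    #include <cstdint>

    // Hypothetical helper (not part of the patch): output size of one spatial
    // dimension of a strided, dilated convolution over a padded input.
    static int64_t convOutputSize(int64_t in, int64_t kernel, int64_t padLow,
                                  int64_t padHigh, int64_t stride,
                                  int64_t dilation) {
      int64_t effectiveKernel = dilation * (kernel - 1) + 1;
      return (in + padLow + padHigh - effectiveKernel) / stride + 1;
    }

    int main() {
      // Shapes from the test below: 28x28 input, 14x14 kernel, no padding,
      // stride 1, dilation 1 -> 15x15 output (100x4x15x15 with batch/filters).
      assert(convOutputSize(28, 14, 0, 0, 1, 1) == 15);
      return 0;
    }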
diff --git a/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir
new file mode 100644
index 000000000..7d0e23113
--- /dev/null
+++ b/compiler/tests/Conversion/FHEToTFHE/FHEToTFHE/conv2d.mlir
@@ -0,0 +1,29 @@
+// RUN: concretecompiler %s --action=dump-tfhe 2>&1| FileCheck %s
+
+//CHECK: #map0 = affine_map<(d0, d1, d2, d3) -> (d1)>
+//CHECK-NEXT: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+//CHECK-NEXT: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
+//CHECK-NEXT: #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
+//CHECK-NEXT: #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
+//CHECK-NEXT: module {
+//CHECK-NEXT:   func @conv2d(%arg0: tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, %arg1: tensor<4x3x14x14xi3>, %arg2: tensor<4xi3>) -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>> {
+//CHECK-NEXT:     %0 = "TFHE.zero_tensor"() : () -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
+//CHECK-NEXT:     %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<4xi3>) outs(%0 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
+//CHECK-NEXT:     ^bb0(%arg3: i3, %arg4: !TFHE.glwe<{_,_,_}{2}>):  // no predecessors
+//CHECK-NEXT:       %3 = "TFHE.add_glwe_int"(%arg4, %arg3) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
+//CHECK-NEXT:       linalg.yield %3 : !TFHE.glwe<{_,_,_}{2}>
+//CHECK-NEXT:     } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
+//CHECK-NEXT:     %2 = linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<100x3x28x28x!TFHE.glwe<{_,_,_}{2}>>, tensor<4x3x14x14xi3>) outs(%1 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>) {
+//CHECK-NEXT:     ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %arg4: i3, %arg5: !TFHE.glwe<{_,_,_}{2}>):  // no predecessors
+//CHECK-NEXT:       %3 = "TFHE.mul_glwe_int"(%arg3, %arg4) : (!TFHE.glwe<{_,_,_}{2}>, i3) -> !TFHE.glwe<{_,_,_}{2}>
+//CHECK-NEXT:       %4 = "TFHE.add_glwe"(%arg5, %3) : (!TFHE.glwe<{_,_,_}{2}>, !TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
+//CHECK-NEXT:       linalg.yield %4 : !TFHE.glwe<{_,_,_}{2}>
+//CHECK-NEXT:     } -> tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
+//CHECK-NEXT:     return %2 : tensor<100x4x15x15x!TFHE.glwe<{_,_,_}{2}>>
+//CHECK-NEXT:   }
+//CHECK-NEXT: }
+
+func @conv2d(%input: tensor<100x3x28x28x!FHE.eint<2>>, %weight: tensor<4x3x14x14xi3>, %bias: tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>> {
+  %1 = "FHELinalg.conv2d"(%input, %weight, %bias){strides = dense<[1,1]> : tensor<2xi64>, dilations = dense<[1,1]> : tensor<2xi64>, padding = dense<[0, 0, 0, 0]> : tensor<4xi64>}: (tensor<100x3x28x28x!FHE.eint<2>>, tensor<4x3x14x14xi3>, tensor<4xi3>) -> tensor<100x4x15x15x!FHE.eint<2>>
+  return %1 : tensor<100x4x15x15x!FHE.eint<2>>
+}
\ No newline at end of file
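To make the CHECK lines above easier to read: the first linalg.generic broadcasts the bias across the batch and spatial dimensions (#map0/#map1), and the second accumulates the NCHW/FCHW convolution (#map2/#map3/#map4). The snippet below is an illustrative plain-integer rendering of those two loop nests (stride and dilation 1, no padding, as in the test), using smaller arbitrary shapes so it runs instantly; it is only a sketch of the indexing semantics, not code from this patch, and it operates on ordinary integers rather than ciphertexts.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Reference semantics of the lowered IR, over plain integers instead of
    // ciphertexts: bias broadcast (#map0/#map1) followed by the NCHW/FCHW
    // convolution accumulation (#map2/#map3/#map4).
    int main() {
      const int N = 2, C = 3, H = 6, W = 6; // batch, channels, height, width
      const int F = 4, KH = 3, KW = 3;      // filters, kernel height/width
      const int OH = H - KH + 1, OW = W - KW + 1;

      std::vector<int64_t> in(N * C * H * W, 1), weight(F * C * KH * KW, 1),
          bias(F, 5), out(N * F * OH * OW, 0);

      // First linalg.generic: out[n][f][oh][ow] = zero + bias[f].
      for (int n = 0; n < N; ++n)
        for (int f = 0; f < F; ++f)
          for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow)
              out[((n * F + f) * OH + oh) * OW + ow] = bias[f];

      // Second linalg.generic: accumulate
      // in[n][c][oh+kh][ow+kw] * weight[f][c][kh][kw] into out[n][f][oh][ow].
      for (int n = 0; n < N; ++n)
        for (int f = 0; f < F; ++f)
          for (int oh = 0; oh < OH; ++oh)
            for (int ow = 0; ow < OW; ++ow)
              for (int c = 0; c < C; ++c)
                for (int kh = 0; kh < KH; ++kh)
                  for (int kw = 0; kw < KW; ++kw)
                    out[((n * F + f) * OH + oh) * OW + ow] +=
                        in[((n * C + c) * H + (oh + kh)) * W + (ow + kw)] *
                        weight[((f * C + c) * KH + kh) * KW + kw];

      // With all-ones input/weights and bias 5: 5 + C*KH*KW = 5 + 27 = 32.
      assert(out[0] == 5 + C * KH * KW);
      return 0;
    }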